Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5e80827369 | |||
| 6894056e46 | |||
| c4402daff1 | |||
| a2fa787121 | |||
| ad7befeb2f | |||
| 1f392ad45b | |||
| fe5034c540 | |||
| 0f5e6454dc | |||
| 638d2bbcd9 | |||
| 4dfcfaa137 | |||
| 1c0f6bb2cf | |||
| 78922ed7c7 | |||
| 6fde5a6dd6 | |||
| d1d0b8afce | |||
| 04ddad484e | |||
| 03d829d59e | |||
| 8d8b4311b9 | |||
| 1fbcc78d6e | |||
| 51593da25a | |||
| 38e563d0c7 | |||
| b8f089c5a3 | |||
| 187ea539ae | |||
| 8bf80fc8d8 | |||
| 45f6d52b10 | |||
| 746215670a | |||
| bc9a8cef6f | |||
| b62d9a1fdc | |||
| 46af98267d | |||
| de1426119d | |||
| 41ea88f38c | |||
| aed7499a8d | |||
| 07c9a08e67 | |||
| 2837d49079 | |||
| 1997614aa9 | |||
| 4e898560ce | |||
| 332d2bbea3 | |||
| b8a5dda56e | |||
| 572d8e2002 |
@@ -9,13 +9,15 @@ on:
|
||||
- v*-patch
|
||||
|
||||
jobs:
|
||||
build:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
install_libgl1: true
|
||||
package: diffusers
|
||||
notebook_folder: diffusers_doc
|
||||
languages: en ko zh
|
||||
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
||||
@@ -13,5 +13,6 @@ jobs:
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
install_libgl1: true
|
||||
package: diffusers
|
||||
languages: en ko
|
||||
languages: en ko zh
|
||||
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev -y
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -61,6 +61,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev -y
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -14,6 +14,7 @@ RUN apt update && \
|
||||
libsndfile1-dev \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
libgl1 \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
@@ -27,6 +28,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
@@ -40,4 +42,4 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -12,6 +12,7 @@ RUN apt update && \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
@@ -26,7 +27,8 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio && \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
|
||||
@@ -184,6 +184,8 @@
|
||||
title: Audio Diffusion
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
title: ControlNet
|
||||
- local: api/pipelines/cycle_diffusion
|
||||
@@ -224,6 +226,8 @@
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
title: Semantic Guidance
|
||||
- local: api/pipelines/shap_e
|
||||
title: Shap-E
|
||||
- local: api/pipelines/spectrogram_diffusion
|
||||
title: Spectrogram Diffusion
|
||||
- sections:
|
||||
@@ -243,6 +247,8 @@
|
||||
title: Safe Stable Diffusion
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_2
|
||||
title: Stable Diffusion 2
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
|
||||
title: Stable Diffusion XL
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Stable-Diffusion-Latent-Upscaler
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
@@ -274,6 +280,8 @@
|
||||
- sections:
|
||||
- local: api/schedulers/overview
|
||||
title: Overview
|
||||
- local: api/schedulers/cm_stochastic_iterative
|
||||
title: Consistency Model Multistep Scheduler
|
||||
- local: api/schedulers/ddim
|
||||
title: DDIM
|
||||
- local: api/schedulers/ddim_inverse
|
||||
|
||||
@@ -32,6 +32,6 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio
|
||||
|
||||
[[autodoc]] loaders.LoraLoaderMixin
|
||||
|
||||
## FromCkptMixin
|
||||
## FromSingleFileMixin
|
||||
|
||||
[[autodoc]] loaders.FromCkptMixin
|
||||
[[autodoc]] loaders.FromSingleFileMixin
|
||||
|
||||
@@ -43,7 +43,7 @@ pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to(devic
|
||||
|
||||
output = pipe()
|
||||
display(output.images[0])
|
||||
display(Audio(output.audios[0], rate=mel.get_sample_rate()))
|
||||
display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate()))
|
||||
```
|
||||
|
||||
### Latent Audio Diffusion
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
# Consistency Models
|
||||
|
||||
Consistency Models were proposed in [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever.
|
||||
|
||||
The abstract of the [paper](https://arxiv.org/pdf/2303.01469.pdf) is as follows:
|
||||
|
||||
*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256. *
|
||||
|
||||
Resources:
|
||||
|
||||
* [Paper](https://arxiv.org/abs/2303.01469)
|
||||
* [Original Code](https://github.com/openai/consistency_models)
|
||||
|
||||
Available Checkpoints are:
|
||||
- *cd_imagenet64_l2 (64x64 resolution)* [openai/consistency-model-pipelines](https://huggingface.co/openai/diffusers-cd_imagenet64_l2)
|
||||
- *cd_imagenet64_lpips (64x64 resolution)* [openai/diffusers-cd_imagenet64_lpips](https://huggingface.co/openai/diffusers-cd_imagenet64_lpips)
|
||||
- *ct_imagenet64 (64x64 resolution)* [openai/diffusers-ct_imagenet64](https://huggingface.co/openai/diffusers-ct_imagenet64)
|
||||
- *cd_bedroom256_l2 (256x256 resolution)* [openai/diffusers-cd_bedroom256_l2](https://huggingface.co/openai/diffusers-cd_bedroom256_l2)
|
||||
- *cd_bedroom256_lpips (256x256 resolution)* [openai/diffusers-cd_bedroom256_lpips](https://huggingface.co/openai/diffusers-cd_bedroom256_lpips)
|
||||
- *ct_bedroom256 (256x256 resolution)* [openai/diffusers-ct_bedroom256](https://huggingface.co/openai/diffusers-ct_bedroom256)
|
||||
- *cd_cat256_l2 (256x256 resolution)* [openai/diffusers-cd_cat256_l2](https://huggingface.co/openai/diffusers-cd_cat256_l2)
|
||||
- *cd_cat256_lpips (256x256 resolution)* [openai/diffusers-cd_cat256_lpips](https://huggingface.co/openai/diffusers-cd_cat256_lpips)
|
||||
- *ct_cat256 (256x256 resolution)* [openai/diffusers-ct_cat256](https://huggingface.co/openai/diffusers-ct_cat256)
|
||||
|
||||
## Available Pipelines
|
||||
|
||||
| Pipeline | Tasks | Demo | Colab |
|
||||
|:---:|:---:|:---:|:---:|
|
||||
| [ConsistencyModelPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_consistency_models.py) | *Unconditional Image Generation* | | |
|
||||
|
||||
This pipeline was contributed by our community members [dg845](https://github.com/dg845) and [ayushtues](https://huggingface.co/ayushtues) ❤️
|
||||
|
||||
## Usage Example
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import ConsistencyModelPipeline
|
||||
|
||||
device = "cuda"
|
||||
# Load the cd_imagenet64_l2 checkpoint.
|
||||
model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
|
||||
pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
|
||||
pipe.to(device)
|
||||
|
||||
# Onestep Sampling
|
||||
image = pipe(num_inference_steps=1).images[0]
|
||||
image.save("consistency_model_onestep_sample.png")
|
||||
|
||||
# Onestep sampling, class-conditional image generation
|
||||
# ImageNet-64 class label 145 corresponds to king penguins
|
||||
image = pipe(num_inference_steps=1, class_labels=145).images[0]
|
||||
image.save("consistency_model_onestep_sample_penguin.png")
|
||||
|
||||
# Multistep sampling, class-conditional image generation
|
||||
# Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo.
|
||||
# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
|
||||
image = pipe(timesteps=[22, 0], class_labels=145).images[0]
|
||||
image.save("consistency_model_multistep_sample_penguin.png")
|
||||
```
|
||||
|
||||
For an additional speed-up, one can also make use of `torch.compile`. Multiple images can be generated in <1 second as follows:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import ConsistencyModelPipeline
|
||||
|
||||
device = "cuda"
|
||||
# Load the cd_bedroom256_lpips checkpoint.
|
||||
model_id_or_path = "openai/diffusers-cd_bedroom256_lpips"
|
||||
pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
|
||||
pipe.to(device)
|
||||
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
# Multistep sampling
|
||||
# Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo:
|
||||
# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L83
|
||||
for _ in range(10):
|
||||
image = pipe(timesteps=[17, 0]).images[0]
|
||||
image.show()
|
||||
```
|
||||
|
||||
## ConsistencyModelPipeline
|
||||
[[autodoc]] ConsistencyModelPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -11,19 +11,12 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
## Overview
|
||||
|
||||
Kandinsky 2.1 inherits best practices from [DALL-E 2](https://arxiv.org/abs/2204.06125) and [Latent Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/latent_diffusion), while introducing some new ideas.
|
||||
Kandinsky inherits best practices from [DALL-E 2](https://huggingface.co/papers/2204.06125) and [Latent Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/latent_diffusion), while introducing some new ideas.
|
||||
|
||||
It uses [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for encoding images and text, and a diffusion image prior (mapping) between latent spaces of CLIP modalities. This approach enhances the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation.
|
||||
|
||||
The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey) and [Denis Dimitrov](https://github.com/denndimitrov) and the original codebase can be found [here](https://github.com/ai-forever/Kandinsky-2)
|
||||
The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey) and [Denis Dimitrov](https://github.com/denndimitrov). The original codebase can be found [here](https://github.com/ai-forever/Kandinsky-2)
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks |
|
||||
|---|---|
|
||||
| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* |
|
||||
| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* |
|
||||
| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* |
|
||||
|
||||
## Usage example
|
||||
|
||||
@@ -135,6 +128,7 @@ prompt = "birds eye view of a quilted paper style alien planet landscape, vibran
|
||||

|
||||
|
||||
|
||||
|
||||
### Text Guided Image-to-Image Generation
|
||||
|
||||
The same Kandinsky model weights can be used for text-guided image-to-image translation. In this case, just make sure to load the weights using the [`KandinskyImg2ImgPipeline`] pipeline.
|
||||
@@ -283,6 +277,207 @@ image.save("starry_cat.png")
|
||||

|
||||
|
||||
|
||||
### Text-to-Image Generation with ControlNet Conditioning
|
||||
|
||||
In the following, we give a simple example of how to use [`KandinskyV22ControlnetPipeline`] to add control to the text-to-image generation with a depth image.
|
||||
|
||||
First, let's take an image and extract its depth map.
|
||||
|
||||
```python
|
||||
from diffusers.utils import load_image
|
||||
|
||||
img = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png"
|
||||
).resize((768, 768))
|
||||
```
|
||||

|
||||
|
||||
We can use the `depth-estimation` pipeline from transformers to process the image and retrieve its depth map.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import pipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
def make_hint(image, depth_estimator):
|
||||
image = depth_estimator(image)["depth"]
|
||||
image = np.array(image)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
detected_map = torch.from_numpy(image).float() / 255.0
|
||||
hint = detected_map.permute(2, 0, 1)
|
||||
return hint
|
||||
|
||||
|
||||
depth_estimator = pipeline("depth-estimation")
|
||||
hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
|
||||
```
|
||||
Now, we load the prior pipeline and the text-to-image controlnet pipeline
|
||||
|
||||
```python
|
||||
from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline
|
||||
|
||||
pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
)
|
||||
pipe_prior = pipe_prior.to("cuda")
|
||||
|
||||
pipe = KandinskyV22ControlnetPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
```
|
||||
|
||||
We pass the prompt and negative prompt through the prior to generate image embeddings
|
||||
|
||||
```python
|
||||
prompt = "A robot, 4k photo"
|
||||
|
||||
negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(43)
|
||||
image_emb, zero_image_emb = pipe_prior(
|
||||
prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator
|
||||
).to_tuple()
|
||||
```
|
||||
|
||||
Now we can pass the image embeddings and the depth image we extracted to the controlnet pipeline. With Kandinsky 2.2, only prior pipelines accept `prompt` input. You do not need to pass the prompt to the controlnet pipeline.
|
||||
|
||||
```python
|
||||
images = pipe(
|
||||
image_embeds=image_emb,
|
||||
negative_image_embeds=zero_image_emb,
|
||||
hint=hint,
|
||||
num_inference_steps=50,
|
||||
generator=generator,
|
||||
height=768,
|
||||
width=768,
|
||||
).images
|
||||
|
||||
images[0].save("robot_cat.png")
|
||||
```
|
||||
|
||||
The output image looks as follow:
|
||||

|
||||
|
||||
### Image-to-Image Generation with ControlNet Conditioning
|
||||
|
||||
Kandinsky 2.2 also includes a [`KandinskyV22ControlnetImg2ImgPipeline`] that will allow you to add control to the image generation process with both the image and its depth map. This pipeline works really well with [`KandinskyV22PriorEmb2EmbPipeline`], which generates image embeddings based on both a text prompt and an image.
|
||||
|
||||
For our robot cat example, we will pass the prompt and cat image together to the prior pipeline to generate an image embedding. We will then use that image embedding and the depth map of the cat to further control the image generation process.
|
||||
|
||||
We can use the same cat image and its depth map from the last example.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
from transformers import pipeline
|
||||
|
||||
img = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinskyv22/cat.png"
|
||||
).resize((768, 768))
|
||||
|
||||
|
||||
def make_hint(image, depth_estimator):
|
||||
image = depth_estimator(image)["depth"]
|
||||
image = np.array(image)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
detected_map = torch.from_numpy(image).float() / 255.0
|
||||
hint = detected_map.permute(2, 0, 1)
|
||||
return hint
|
||||
|
||||
|
||||
depth_estimator = pipeline("depth-estimation")
|
||||
hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
|
||||
|
||||
pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
)
|
||||
pipe_prior = pipe_prior.to("cuda")
|
||||
|
||||
pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "A robot, 4k photo"
|
||||
negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(43)
|
||||
|
||||
# run prior pipeline
|
||||
|
||||
img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator)
|
||||
negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)
|
||||
|
||||
# run controlnet img2img pipeline
|
||||
images = pipe(
|
||||
image=img,
|
||||
strength=0.5,
|
||||
image_embeds=img_emb.image_embeds,
|
||||
negative_image_embeds=negative_emb.image_embeds,
|
||||
hint=hint,
|
||||
num_inference_steps=50,
|
||||
generator=generator,
|
||||
height=768,
|
||||
width=768,
|
||||
).images
|
||||
|
||||
images[0].save("robot_cat.png")
|
||||
```
|
||||
|
||||
Here is the output. Compared with the output from our text-to-image controlnet example, it kept a lot more cat facial details from the original image and worked into the robot style we asked for.
|
||||
|
||||

|
||||
|
||||
## Kandinsky 2.2
|
||||
|
||||
The Kandinsky 2.2 release includes robust new text-to-image models that support text-to-image generation, image-to-image generation, image interpolation, and text-guided image inpainting. The general workflow to perform these tasks using Kandinsky 2.2 is the same as in Kandinsky 2.1. First, you will need to use a prior pipeline to generate image embeddings based on your text prompt, and then use one of the image decoding pipelines to generate the output image. The only difference is that in Kandinsky 2.2, all of the decoding pipelines no longer accept the `prompt` input, and the image generation process is conditioned with only `image_embeds` and `negative_image_embeds`.
|
||||
|
||||
Let's look at an example of how to perform text-to-image generation using Kandinsky 2.2.
|
||||
|
||||
First, let's create the prior pipeline and text-to-image pipeline with Kandinsky 2.2 checkpoints.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16)
|
||||
pipe_prior.to("cuda")
|
||||
|
||||
t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
|
||||
t2i_pipe.to("cuda")
|
||||
```
|
||||
|
||||
You can then use `pipe_prior` to generate image embeddings.
|
||||
|
||||
```python
|
||||
prompt = "portrait of a women, blue eyes, cinematic"
|
||||
negative_prompt = "low quality, bad quality"
|
||||
|
||||
image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
|
||||
```
|
||||
|
||||
Now you can pass these embeddings to the text-to-image pipeline. When using Kandinsky 2.2 you don't need to pass the `prompt` (but you do with the previous version, Kandinsky 2.1).
|
||||
|
||||
```
|
||||
image = t2i_pipe(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768).images[
|
||||
0
|
||||
]
|
||||
image.save("portrait.png")
|
||||
```
|
||||

|
||||
|
||||
We used the text-to-image pipeline as an example, but the same process applies to all decoding pipelines in Kandinsky 2.2. For more information, please refer to our API section for each pipeline.
|
||||
|
||||
|
||||
## Optimization
|
||||
|
||||
Running Kandinsky in inference requires running both a first prior pipeline: [`KandinskyPriorPipeline`]
|
||||
@@ -335,30 +530,84 @@ t2i_pipe.unet = torch.compile(t2i_pipe.unet, mode="reduce-overhead", fullgraph=T
|
||||
After compilation you should see a very fast inference time. For more information,
|
||||
feel free to have a look at [Our PyTorch 2.0 benchmark](https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0).
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks |
|
||||
|---|---|
|
||||
| [pipeline_kandinsky2_2.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py) | *Text-to-Image Generation* |
|
||||
| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* |
|
||||
| [pipeline_kandinsky2_2_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py) | *Image-Guided Image Generation* |
|
||||
| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* |
|
||||
| [pipeline_kandinsky2_2_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py) | *Image-Guided Image Generation* |
|
||||
| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* |
|
||||
| [pipeline_kandinsky2_2_controlnet.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py) | *Image-Guided Image Generation* |
|
||||
| [pipeline_kandinsky2_2_controlnet_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py) | *Image-Guided Image Generation* |
|
||||
|
||||
|
||||
### KandinskyV22Pipeline
|
||||
|
||||
[[autodoc]] KandinskyV22Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## KandinskyPriorPipeline
|
||||
### KandinskyV22ControlnetPipeline
|
||||
|
||||
[[autodoc]] KandinskyV22ControlnetPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
### KandinskyV22ControlnetImg2ImgPipeline
|
||||
|
||||
[[autodoc]] KandinskyV22ControlnetImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
### KandinskyV22Img2ImgPipeline
|
||||
|
||||
[[autodoc]] KandinskyV22Img2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
### KandinskyV22InpaintPipeline
|
||||
|
||||
[[autodoc]] KandinskyV22InpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
### KandinskyV22PriorPipeline
|
||||
|
||||
[[autodoc]] ## KandinskyV22PriorPipeline
|
||||
- all
|
||||
- __call__
|
||||
- interpolate
|
||||
|
||||
### KandinskyV22PriorEmb2EmbPipeline
|
||||
|
||||
[[autodoc]] KandinskyV22PriorEmb2EmbPipeline
|
||||
- all
|
||||
- __call__
|
||||
- interpolate
|
||||
|
||||
### KandinskyPriorPipeline
|
||||
|
||||
[[autodoc]] KandinskyPriorPipeline
|
||||
- all
|
||||
- __call__
|
||||
- interpolate
|
||||
|
||||
## KandinskyPipeline
|
||||
### KandinskyPipeline
|
||||
|
||||
[[autodoc]] KandinskyPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## KandinskyImg2ImgPipeline
|
||||
### KandinskyImg2ImgPipeline
|
||||
|
||||
[[autodoc]] KandinskyImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## KandinskyInpaintPipeline
|
||||
### KandinskyInpaintPipeline
|
||||
|
||||
[[autodoc]] KandinskyInpaintPipeline
|
||||
- all
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Shap-E
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
The Shap-E model was proposed in [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463) by Alex Nichol and Heewon Jun from [OpenAI](https://github.com/openai).
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*We present Shap-E, a conditional generative model for 3D assets. Unlike recent work on 3D generative models which produce a single output representation, Shap-E directly generates the parameters of implicit functions that can be rendered as both textured meshes and neural radiance fields. We train Shap-E in two stages: first, we train an encoder that deterministically maps 3D assets into the parameters of an implicit function; second, we train a conditional diffusion model on outputs of the encoder. When trained on a large dataset of paired 3D and text data, our resulting models are capable of generating complex and diverse 3D assets in a matter of seconds. When compared to Point-E, an explicit generative model over point clouds, Shap-E converges faster and reaches comparable or better sample quality despite modeling a higher-dimensional, multi-representation output space.*
|
||||
|
||||
The original codebase can be found [here](https://github.com/openai/shap-e).
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks |
|
||||
|---|---|
|
||||
| [pipeline_shap_e.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e.py) | *Text-to-Image Generation* |
|
||||
| [pipeline_shap_e_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py) | *Image-to-Image Generation* |
|
||||
|
||||
## Available checkpoints
|
||||
|
||||
* [`openai/shap-e`](https://huggingface.co/openai/shap-e)
|
||||
* [`openai/shap-e-img2img`](https://huggingface.co/openai/shap-e-img2img)
|
||||
|
||||
## Usage Examples
|
||||
|
||||
In the following, we will walk you through some examples of how to use Shap-E pipelines to create 3D objects in gif format.
|
||||
|
||||
### Text-to-3D image generation
|
||||
|
||||
We can use [`ShapEPipeline`] to create 3D object based on a text prompt. In this example, we will make a birthday cupcake for :firecracker: diffusers library's 1 year birthday. The workflow to use the Shap-E text-to-image pipeline is same as how you would use other text-to-image pipelines in diffusers.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
repo = "openai/shap-e"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
guidance_scale = 15.0
|
||||
prompt = ["A firecracker", "A birthday cupcake"]
|
||||
|
||||
images = pipe(
|
||||
prompt,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=64,
|
||||
frame_size=256,
|
||||
).images
|
||||
```
|
||||
|
||||
The output of [`ShapEPipeline`] is a list of lists of images frames. Each list of frames can be used to create a 3D object. Let's use the `export_to_gif` utility function in diffusers to make a 3D cupcake!
|
||||
|
||||
```python
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
export_to_gif(images[0], "firecracker_3d.gif")
|
||||
export_to_gif(images[1], "cake_3d.gif")
|
||||
```
|
||||

|
||||

|
||||
|
||||
|
||||
### Image-to-Image generation
|
||||
|
||||
You can use [`ShapEImg2ImgPipeline`] along with other text-to-image pipelines in diffusers and turn your 2D generation into 3D.
|
||||
|
||||
In this example, We will first genrate a cheeseburger with a simple prompt "A cheeseburger, white background"
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16)
|
||||
pipe_prior.to("cuda")
|
||||
|
||||
t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
|
||||
t2i_pipe.to("cuda")
|
||||
|
||||
prompt = "A cheeseburger, white background"
|
||||
|
||||
image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
|
||||
image = t2i_pipe(
|
||||
prompt,
|
||||
image_embeds=image_embeds,
|
||||
negative_image_embeds=negative_image_embeds,
|
||||
).images[0]
|
||||
|
||||
image.save("burger.png")
|
||||
```
|
||||
|
||||

|
||||
|
||||
we will then use the Shap-E image-to-image pipeline to turn it into a 3D cheeseburger :)
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
repo = "openai/shap-e-img2img"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
guidance_scale = 3.0
|
||||
image = Image.open("burger.png").resize((256, 256))
|
||||
|
||||
images = pipe(
|
||||
image,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=64,
|
||||
frame_size=256,
|
||||
).images
|
||||
|
||||
gif_path = export_to_gif(images[0], "burger_3d.gif")
|
||||
```
|
||||

|
||||
|
||||
## ShapEPipeline
|
||||
[[autodoc]] ShapEPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ShapEImg2ImgPipeline
|
||||
[[autodoc]] ShapEImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -31,7 +31,7 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- load_textual_inversion
|
||||
- from_ckpt
|
||||
- from_single_file
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Stable diffusion XL
|
||||
|
||||
Stable Diffusion XL was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/abs/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, Robin Rombach
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*We present SDXL, a latent diffusion model for text-to-image synthesis. Compared to previous versions of Stable Diffusion, SDXL leverages a three times larger UNet backbone: The increase of model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder. We design multiple novel conditioning schemes and train SDXL on multiple aspect ratios. We also introduce a refinement model which is used to improve the visual fidelity of samples generated by SDXL using a post-hoc image-to-image technique. We demonstrate that SDXL shows drastically improved performance compared the previous versions of Stable Diffusion and achieves results competitive with those of black-box state-of-the-art image generators.*
|
||||
|
||||
## Tips
|
||||
|
||||
- Stable Diffusion XL works especially well with images between 768 and 1024.
|
||||
- Stable Diffusion XL output image can be improved by making use of a refiner as shown below.
|
||||
|
||||
### Available checkpoints:
|
||||
|
||||
- *Text-to-Image (1024x1024 resolution)*: [stabilityai/stable-diffusion-xl-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) with [`StableDiffusionXLPipeline`]
|
||||
- *Image-to-Image / Refiner (1024x1024 resolution)*: [stabilityai/stable-diffusion-xl-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9) with [`StableDiffusionXLImg2ImgPipeline`]
|
||||
|
||||
## Usage Example
|
||||
|
||||
Before using SDXL make sure to have `transformers`, `accelerate`, `safetensors` and `invisible_watermark` installed.
|
||||
You can install the libraries as follows:
|
||||
|
||||
```
|
||||
pip install transformers
|
||||
pip install accelerate
|
||||
pip install safetensors
|
||||
pip install invisible-watermark>=2.0
|
||||
```
|
||||
|
||||
### Text-to-Image
|
||||
|
||||
You can use SDXL as follows for *text-to-image*:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt=prompt).images[0]
|
||||
```
|
||||
|
||||
### Refining the image output
|
||||
|
||||
The image can be refined by making use of [stabilityai/stable-diffusion-xl-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
|
||||
In this case, you only have to output the `latents` from the base model.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
use_refiner = True
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
||||
)
|
||||
refiner.to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
image = pipe(prompt=prompt, output_type="latent" if use_refiner else "pil").images[0]
|
||||
image = refiner(prompt=prompt, image=image[None, :]).images[0]
|
||||
```
|
||||
|
||||
### Image-to-image
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLImg2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
|
||||
|
||||
init_image = load_image(url).convert("RGB")
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
image = pipe(prompt, image=init_image).images[0]
|
||||
```
|
||||
|
||||
| Original Image | Refined Image |
|
||||
|---|---|
|
||||
|  |  |
|
||||
|
||||
### Loading single file checkpoints / original file format
|
||||
|
||||
By making use of [`~diffusers.loaders.FromSingleFileMixin.from_single_file`] you can also load the
|
||||
original file format into `diffusers`:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
||||
)
|
||||
refiner.to("cuda")
|
||||
```
|
||||
|
||||
### Memory optimization via model offloading
|
||||
|
||||
If you are seeing out-of-memory errors, we recommend making use of [`StableDiffusionXLPipeline.enable_model_cpu_offload`].
|
||||
|
||||
```diff
|
||||
- pipe.to("cuda")
|
||||
+ pipe.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
and
|
||||
|
||||
```diff
|
||||
- refiner.to("cuda")
|
||||
+ refiner.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
### Speed-up inference with `torch.compile`
|
||||
|
||||
You can speed up inference by making use of `torch.compile`. This should give you **ca.** 20% speed-up.
|
||||
|
||||
```diff
|
||||
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
+ refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
### Running with `torch` \< 2.0
|
||||
|
||||
**Note** that if you want to run Stable Diffusion XL with `torch` < 2.0, please make sure to enable xformers
|
||||
attention:
|
||||
|
||||
```
|
||||
pip install xformers
|
||||
```
|
||||
|
||||
```diff
|
||||
+pipe.enable_xformers_memory_efficient_attention()
|
||||
+refiner.enable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
## StableDiffusionXLPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionXLPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionXLImg2ImgPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionXLImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -40,7 +40,7 @@ Available Checkpoints are:
|
||||
- enable_vae_tiling
|
||||
- disable_vae_tiling
|
||||
- load_textual_inversion
|
||||
- from_ckpt
|
||||
- from_single_file
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
|
||||
@@ -138,6 +138,7 @@ pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dt
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# memory optimization
|
||||
pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
prompt = "Darth Vader surfing a wave"
|
||||
@@ -150,10 +151,13 @@ Now the video can be upscaled:
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16)
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# memory optimization
|
||||
pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]
|
||||
|
||||
video_frames = pipe(prompt, video=video, strength=0.6).frames
|
||||
@@ -175,6 +179,28 @@ Here are some sample outputs:
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### Memory optimizations
|
||||
|
||||
Text-guided video generation with [`~TextToVideoSDPipeline`] and [`~VideoToVideoSDPipeline`] is very memory intensive both
|
||||
when denoising with [`~UNet3DConditionModel`] and when decoding with [`~AutoencoderKL`]. It is possible though to reduce
|
||||
memory usage at the cost of increased runtime to achieve the exact same result. To do so, it is recommended to enable
|
||||
**forward chunking** and **vae slicing**:
|
||||
|
||||
Forward chunking via [`~UNet3DConditionModel.enable_forward_chunking`]is explained in [this blog post](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) and
|
||||
allows to significantly reduce the required memory for the unet. You can chunk the feed forward layer over the `num_frames`
|
||||
dimension by doing:
|
||||
|
||||
```py
|
||||
pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
|
||||
```
|
||||
|
||||
Vae slicing via [`~TextToVideoSDPipeline.enable_vae_slicing`] and [`~VideoToVideoSDPipeline.enable_vae_slicing`] also
|
||||
gives significant memory savings since the two pipelines decode all image frames at once.
|
||||
|
||||
```py
|
||||
pipe.enable_vae_slicing()
|
||||
```
|
||||
|
||||
## Available checkpoints
|
||||
|
||||
* [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/)
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
# Consistency Model Multistep Scheduler
|
||||
|
||||
## Overview
|
||||
|
||||
Multistep and onestep scheduler (Algorithm 1) introduced alongside consistency models in the paper [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever.
|
||||
Based on the [original consistency models implementation](https://github.com/openai/consistency_models).
|
||||
Should generate good samples from [`ConsistencyModelPipeline`] in one or a small number of steps.
|
||||
|
||||
## CMStochasticIterativeScheduler
|
||||
[[autodoc]] CMStochasticIterativeScheduler
|
||||
|
||||
@@ -174,7 +174,7 @@ A checkpoint variant is usually a checkpoint where it's weights are:
|
||||
|
||||
</Tip>
|
||||
|
||||
Otherwise, a variant is **identical** to the original checkpoint. They have exactly the same serialization format (like [Safetensors](./using-diffusers/using_safetensors)), model structure, and weights have identical tensor shapes.
|
||||
Otherwise, a variant is **identical** to the original checkpoint. They have exactly the same serialization format (like [Safetensors](./using_safetensors)), model structure, and weights have identical tensor shapes.
|
||||
|
||||
| **checkpoint type** | **weight name** | **argument for loading weights** |
|
||||
|---------------------|-------------------------------------|----------------------------------|
|
||||
@@ -190,6 +190,7 @@ There are two important arguments to know for loading variants:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
# load fp16 variant
|
||||
stable_diffusion = DiffusionPipeline.from_pretrained(
|
||||
|
||||
@@ -26,7 +26,7 @@ This guide will show you how to convert other Stable Diffusion formats to be com
|
||||
|
||||
## PyTorch .ckpt
|
||||
|
||||
The checkpoint - or `.ckpt` - format is commonly used to store and save models. The `.ckpt` file contains the entire model and is typically several GBs in size. While you can load and use a `.ckpt` file directly with the [`~StableDiffusionPipeline.from_ckpt`] method, it is generally better to convert the `.ckpt` file to 🤗 Diffusers so both formats are available.
|
||||
The checkpoint - or `.ckpt` - format is commonly used to store and save models. The `.ckpt` file contains the entire model and is typically several GBs in size. While you can load and use a `.ckpt` file directly with the [`~StableDiffusionPipeline.from_single_file`] method, it is generally better to convert the `.ckpt` file to 🤗 Diffusers so both formats are available.
|
||||
|
||||
There are two options for converting a `.ckpt` file; use a Space to convert the checkpoint or convert the `.ckpt` file with a script.
|
||||
|
||||
|
||||
@@ -21,12 +21,12 @@ from diffusers import DiffusionPipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
|
||||
```
|
||||
|
||||
However, model weights are not necessarily stored in separate subfolders like in the example above. Sometimes, all the weights are stored in a single `.safetensors` file. In this case, if the weights are Stable Diffusion weights, you can load the file directly with the [`~diffusers.loaders.FromCkptMixin.from_ckpt`] method:
|
||||
However, model weights are not necessarily stored in separate subfolders like in the example above. Sometimes, all the weights are stored in a single `.safetensors` file. In this case, if the weights are Stable Diffusion weights, you can load the file directly with the [`~diffusers.loaders.FromSingleFileMixin.from_single_file`] method:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_ckpt(
|
||||
pipeline = StableDiffusionPipeline.from_single_file(
|
||||
"https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||
)
|
||||
```
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
@@ -410,7 +410,7 @@ def preprocess_mask(mask, batch_size, scale_factor=8):
|
||||
|
||||
|
||||
class StableDiffusionLongPromptWeightingPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -65,7 +65,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -436,6 +436,12 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -845,7 +851,9 @@ def main(args):
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
unet_lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
rank=args.rank,
|
||||
)
|
||||
|
||||
unet.set_attn_processor(unet_lora_attn_procs)
|
||||
@@ -860,7 +868,9 @@ def main(args):
|
||||
for name, module in text_encoder.named_modules():
|
||||
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
|
||||
text_lora_attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=module.out_proj.out_features, cross_attention_dim=None
|
||||
hidden_size=module.out_proj.out_features,
|
||||
cross_attention_dim=None,
|
||||
rank=args.rank,
|
||||
)
|
||||
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
|
||||
temp_pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -86,6 +86,53 @@ This example shows training for 2 subjects, but please note that the model can b
|
||||
|
||||
Note also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed to be good when more intuitive words are used.
|
||||
|
||||
**Important**: New parameters are added to the script, making possible to validate the progress of the training by
|
||||
generating images at specified steps. Taking also into account that a comma separated list in a text field for a prompt
|
||||
it's never a good idea (simply because it is very common in prompts to have them as part of a regular text) we
|
||||
introduce the `concept_list` parameter: allowing to specify a json-like file where you can define the different
|
||||
configuration for each subject that you want to train.
|
||||
|
||||
An example of how to generate the file:
|
||||
```python
|
||||
import json
|
||||
|
||||
# here we are using parameters for prior-preservation and validation as well.
|
||||
concepts_list = [
|
||||
{
|
||||
"instance_prompt": "drawing of a t@y meme",
|
||||
"class_prompt": "drawing of a meme",
|
||||
"instance_data_dir": "/some_folder/meme_toy",
|
||||
"class_data_dir": "/data/meme",
|
||||
"validation_prompt": "drawing of a t@y meme about football in Uruguay",
|
||||
"validation_negative_prompt": "black and white"
|
||||
},
|
||||
{
|
||||
"instance_prompt": "drawing of a sks sir",
|
||||
"class_prompt": "drawing of a sir",
|
||||
"instance_data_dir": "/some_other_folder/sir_sks",
|
||||
"class_data_dir": "/data/sir",
|
||||
"validation_prompt": "drawing of a sks sir with the Uruguayan sun in his chest",
|
||||
"validation_negative_prompt": "an old man",
|
||||
"validation_guidance_scale": 20,
|
||||
"validation_number_images": 3,
|
||||
"validation_inference_steps": 10
|
||||
}
|
||||
]
|
||||
|
||||
with open("concepts_list.json", "w") as f:
|
||||
json.dump(concepts_list, f, indent=4)
|
||||
```
|
||||
And then just point to the file when executing the script:
|
||||
|
||||
```bash
|
||||
# exports...
|
||||
accelerate launch train_multi_subject_dreambooth.py \
|
||||
# more parameters...
|
||||
--concepts_list="concepts_list.json"
|
||||
```
|
||||
|
||||
You can use the helper from the script to get a better sense of each parameter.
|
||||
|
||||
### Inference
|
||||
|
||||
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
|
||||
|
||||
+357
-53
@@ -1,13 +1,18 @@
|
||||
import argparse
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import uuid
|
||||
import warnings
|
||||
from os import environ, listdir, makedirs
|
||||
from os.path import basename, join
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
@@ -17,24 +22,140 @@ from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from PIL import Image
|
||||
from torch import dtype
|
||||
from torch.nn import Module
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def log_validation_images_to_tracker(
|
||||
images: List[np.array], label: str, validation_prompt: str, accelerator: Accelerator, epoch: int
|
||||
):
|
||||
logger.info(f"Logging images to tracker for validation prompt: {validation_prompt}.")
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{label}_{epoch}_{i}: {validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings`
|
||||
# argument is implemented.
|
||||
def generate_validation_images(
|
||||
text_encoder: Module,
|
||||
tokenizer: Module,
|
||||
unet: Module,
|
||||
vae: Module,
|
||||
arguments: argparse.Namespace,
|
||||
accelerator: Accelerator,
|
||||
weight_dtype: dtype,
|
||||
):
|
||||
logger.info("Running validation images.")
|
||||
|
||||
pipeline_args = {}
|
||||
|
||||
if text_encoder is not None:
|
||||
pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
if vae is not None:
|
||||
pipeline_args["vae"] = vae
|
||||
|
||||
# create pipeline (note: unet and vae are loaded again in float32)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
arguments.pretrained_model_name_or_path,
|
||||
tokenizer=tokenizer,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
revision=arguments.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
**pipeline_args,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the
|
||||
# scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
generator = (
|
||||
None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed)
|
||||
)
|
||||
|
||||
images_sets = []
|
||||
for vp, nvi, vnp, vis, vgs in zip(
|
||||
arguments.validation_prompt,
|
||||
arguments.validation_number_images,
|
||||
arguments.validation_negative_prompt,
|
||||
arguments.validation_inference_steps,
|
||||
arguments.validation_guidance_scale,
|
||||
):
|
||||
images = []
|
||||
if vp is not None:
|
||||
logger.info(
|
||||
f"Generating {nvi} images with prompt: '{vp}', negative prompt: '{vnp}', inference steps: {vis}, "
|
||||
f"guidance scale: {vgs}."
|
||||
)
|
||||
|
||||
pipeline_args = {"prompt": vp, "negative_prompt": vnp, "num_inference_steps": vis, "guidance_scale": vgs}
|
||||
|
||||
# run inference
|
||||
# TODO: it would be good to measure whether it's faster to run inference on all images at once, one at a
|
||||
# time or in small batches
|
||||
for _ in range(nvi):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(**pipeline_args, num_images_per_prompt=1, generator=generator).images[0]
|
||||
images.append(image)
|
||||
|
||||
images_sets.append(images)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images_sets
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
@@ -81,7 +202,7 @@ def parse_args(input_args=None):
|
||||
"--instance_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
required=False,
|
||||
help="A folder containing the training data of instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -95,7 +216,7 @@ def parse_args(input_args=None):
|
||||
"--instance_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
required=False,
|
||||
help="The prompt with identifier specifying the instance",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -272,6 +393,52 @@ def parse_args(input_args=None):
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Run validation every X steps. Validation consists of running the prompt(s) `validation_prompt` "
|
||||
"multiple times (`validation_number_images`) and logging the images."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A prompt that is used during validation to verify that the model is learning. You can use commas to "
|
||||
"define multiple negative prompts. This parameter can be defined also within the file given by "
|
||||
"`concepts_list` parameter in the respective subject.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_number_images",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of images that should be generated during validation with the validation parameters given. This "
|
||||
"can be defined within the file given by `concepts_list` parameter in the respective subject.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_negative_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A negative prompt that is used during validation to verify that the model is learning. You can use commas"
|
||||
" to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can "
|
||||
"be defined also within the file given by `concepts_list` parameter in the respective subject.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_inference_steps",
|
||||
type=int,
|
||||
default=25,
|
||||
help="Number of inference steps (denoising steps) to run during validation. This can be defined within the "
|
||||
"file given by `concepts_list` parameter in the respective subject.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="To control how much the image generation process follows the text prompt. This can be defined within the "
|
||||
"file given by `concepts_list` parameter in the respective subject.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
@@ -297,27 +464,80 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--set_grads_to_none",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
|
||||
" behaviors, so disable this argument if it causes any problems. More info:"
|
||||
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concepts_list",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to json file containing a list of multiple concepts, will overwrite parameters like instance_prompt,"
|
||||
" class_prompt, etc.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
if input_args:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if not args.concepts_list and (not args.instance_data_dir or not args.instance_prompt):
|
||||
raise ValueError(
|
||||
"You must specify either instance parameters (data directory, prompt, etc.) or use "
|
||||
"the `concept_list` parameter and specify them within the file."
|
||||
)
|
||||
|
||||
if args.concepts_list:
|
||||
if args.instance_prompt:
|
||||
raise ValueError("If you are using `concepts_list` parameter, define the instance prompt within the file.")
|
||||
if args.instance_data_dir:
|
||||
raise ValueError(
|
||||
"If you are using `concepts_list` parameter, define the instance data directory within the file."
|
||||
)
|
||||
if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt):
|
||||
raise ValueError(
|
||||
"If you are using `concepts_list` parameter, define validation parameters for "
|
||||
"each subject within the file:\n - `validation_prompt`."
|
||||
"\n - `validation_negative_prompt`.\n - `validation_guidance_scale`."
|
||||
"\n - `validation_number_images`.\n - `validation_prompt`."
|
||||
"\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one "
|
||||
"that needs to be defined outside the file."
|
||||
)
|
||||
|
||||
env_local_rank = int(environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.with_prior_preservation:
|
||||
if args.class_data_dir is None:
|
||||
raise ValueError("You must specify a data directory for class images.")
|
||||
if args.class_prompt is None:
|
||||
raise ValueError("You must specify prompt for class images.")
|
||||
if not args.concepts_list:
|
||||
if not args.class_data_dir:
|
||||
raise ValueError("You must specify a data directory for class images.")
|
||||
if not args.class_prompt:
|
||||
raise ValueError("You must specify prompt for class images.")
|
||||
else:
|
||||
if args.class_data_dir:
|
||||
raise ValueError(
|
||||
"If you are using `concepts_list` parameter, define the class data directory within the file."
|
||||
)
|
||||
if args.class_prompt:
|
||||
raise ValueError(
|
||||
"If you are using `concepts_list` parameter, define the class prompt within the file."
|
||||
)
|
||||
else:
|
||||
# logger is not available yet
|
||||
if args.class_data_dir is not None:
|
||||
warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
|
||||
if args.class_prompt is not None:
|
||||
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
|
||||
if not args.class_data_dir:
|
||||
warnings.warn(
|
||||
"Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`."
|
||||
)
|
||||
if not args.class_prompt:
|
||||
warnings.warn(
|
||||
"Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`."
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
@@ -325,7 +545,7 @@ def parse_args(input_args=None):
|
||||
class DreamBoothDataset(Dataset):
|
||||
"""
|
||||
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
||||
It pre-processes the images and the tokenizes prompts.
|
||||
It pre-processes the images and then tokenizes prompts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -346,7 +566,7 @@ class DreamBoothDataset(Dataset):
|
||||
self.instance_images_path = []
|
||||
self.num_instance_images = []
|
||||
self.instance_prompt = []
|
||||
self.class_data_root = []
|
||||
self.class_data_root = [] if class_data_root is not None else None
|
||||
self.class_images_path = []
|
||||
self.num_class_images = []
|
||||
self.class_prompt = []
|
||||
@@ -371,8 +591,6 @@ class DreamBoothDataset(Dataset):
|
||||
self._length -= self.num_instance_images[i]
|
||||
self._length += self.num_class_images[i]
|
||||
self.class_prompt.append(class_prompt[i])
|
||||
else:
|
||||
self.class_data_root = None
|
||||
|
||||
self.image_transforms = transforms.Compose(
|
||||
[
|
||||
@@ -446,7 +664,7 @@ def collate_fn(num_instances, examples, with_prior_preservation=False):
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
||||
"""A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
|
||||
|
||||
def __init__(self, prompt, num_samples):
|
||||
self.prompt = prompt
|
||||
@@ -474,6 +692,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
||||
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
@@ -483,23 +705,84 @@ def main(args):
|
||||
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
||||
)
|
||||
|
||||
# Parse instance and class inputs, and double check that lengths match
|
||||
instance_data_dir = args.instance_data_dir.split(",")
|
||||
instance_prompt = args.instance_prompt.split(",")
|
||||
assert all(
|
||||
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
|
||||
), "Instance data dir and prompt inputs are not of the same length."
|
||||
instance_data_dir = []
|
||||
instance_prompt = []
|
||||
class_data_dir = [] if args.with_prior_preservation else None
|
||||
class_prompt = [] if args.with_prior_preservation else None
|
||||
if args.concepts_list:
|
||||
with open(args.concepts_list, "r") as f:
|
||||
concepts_list = json.load(f)
|
||||
|
||||
if args.with_prior_preservation:
|
||||
class_data_dir = args.class_data_dir.split(",")
|
||||
class_prompt = args.class_prompt.split(",")
|
||||
assert all(
|
||||
x == len(instance_data_dir)
|
||||
for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)]
|
||||
), "Instance & class data dir or prompt inputs are not of the same length."
|
||||
if args.validation_steps:
|
||||
args.validation_prompt = []
|
||||
args.validation_number_images = []
|
||||
args.validation_negative_prompt = []
|
||||
args.validation_inference_steps = []
|
||||
args.validation_guidance_scale = []
|
||||
|
||||
for concept in concepts_list:
|
||||
instance_data_dir.append(concept["instance_data_dir"])
|
||||
instance_prompt.append(concept["instance_prompt"])
|
||||
|
||||
if args.with_prior_preservation:
|
||||
try:
|
||||
class_data_dir.append(concept["class_data_dir"])
|
||||
class_prompt.append(concept["class_prompt"])
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
"`class_data_dir` or `class_prompt` not found in concepts_list while using "
|
||||
"`with_prior_preservation`."
|
||||
)
|
||||
else:
|
||||
if "class_data_dir" in concept:
|
||||
warnings.warn(
|
||||
"Ignoring `class_data_dir` key, to use it you need to enable `with_prior_preservation`."
|
||||
)
|
||||
if "class_prompt" in concept:
|
||||
warnings.warn(
|
||||
"Ignoring `class_prompt` key, to use it you need to enable `with_prior_preservation`."
|
||||
)
|
||||
|
||||
if args.validation_steps:
|
||||
args.validation_prompt.append(concept.get("validation_prompt", None))
|
||||
args.validation_number_images.append(concept.get("validation_number_images", 4))
|
||||
args.validation_negative_prompt.append(concept.get("validation_negative_prompt", None))
|
||||
args.validation_inference_steps.append(concept.get("validation_inference_steps", 25))
|
||||
args.validation_guidance_scale.append(concept.get("validation_guidance_scale", 7.5))
|
||||
else:
|
||||
class_data_dir = args.class_data_dir
|
||||
class_prompt = args.class_prompt
|
||||
# Parse instance and class inputs, and double check that lengths match
|
||||
instance_data_dir = args.instance_data_dir.split(",")
|
||||
instance_prompt = args.instance_prompt.split(",")
|
||||
assert all(
|
||||
x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
|
||||
), "Instance data dir and prompt inputs are not of the same length."
|
||||
|
||||
if args.with_prior_preservation:
|
||||
class_data_dir = args.class_data_dir.split(",")
|
||||
class_prompt = args.class_prompt.split(",")
|
||||
assert all(
|
||||
x == len(instance_data_dir)
|
||||
for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)]
|
||||
), "Instance & class data dir or prompt inputs are not of the same length."
|
||||
|
||||
if args.validation_steps:
|
||||
validation_prompts = args.validation_prompt.split(",")
|
||||
num_of_validation_prompts = len(validation_prompts)
|
||||
args.validation_prompt = validation_prompts
|
||||
args.validation_number_images = [args.validation_number_images] * num_of_validation_prompts
|
||||
|
||||
negative_validation_prompts = [None] * num_of_validation_prompts
|
||||
if args.validation_negative_prompt:
|
||||
negative_validation_prompts = args.validation_negative_prompt.split(",")
|
||||
while len(negative_validation_prompts) < num_of_validation_prompts:
|
||||
negative_validation_prompts.append(None)
|
||||
args.validation_negative_prompt = negative_validation_prompts
|
||||
|
||||
assert num_of_validation_prompts == len(
|
||||
negative_validation_prompts
|
||||
), "The length of negative prompts for validation is greater than the number of validation prompts."
|
||||
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
|
||||
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
@@ -559,21 +842,24 @@ def main(args):
|
||||
):
|
||||
images = pipeline(example["prompt"]).images
|
||||
|
||||
for i, image in enumerate(images):
|
||||
for ii, image in enumerate(images):
|
||||
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
||||
image_filename = (
|
||||
class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
||||
class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg"
|
||||
)
|
||||
image.save(image_filename)
|
||||
|
||||
# Clean up the memory deleting one-time-use variables.
|
||||
del pipeline
|
||||
del sample_dataloader
|
||||
del sample_dataset
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
@@ -581,6 +867,7 @@ def main(args):
|
||||
).repo_id
|
||||
|
||||
# Load the tokenizer
|
||||
tokenizer = None
|
||||
if args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
@@ -658,7 +945,7 @@ def main(args):
|
||||
train_dataset = DreamBoothDataset(
|
||||
instance_data_root=instance_data_dir,
|
||||
instance_prompt=instance_prompt,
|
||||
class_data_root=class_data_dir if args.with_prior_preservation else None,
|
||||
class_data_root=class_data_dir,
|
||||
class_prompt=class_prompt,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
@@ -720,7 +1007,7 @@ def main(args):
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
# The trackers initialize automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth", config=vars(args))
|
||||
|
||||
@@ -741,10 +1028,10 @@ def main(args):
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
path = basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the mos recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
@@ -756,7 +1043,7 @@ def main(args):
|
||||
args.resume_from_checkpoint = None
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
accelerator.load_state(join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
@@ -787,24 +1074,26 @@ def main(args):
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
time_steps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
||||
)
|
||||
time_steps = time_steps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, time_steps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
model_pred = unet(noisy_latents, time_steps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
target = noise_scheduler.get_velocity(latents, noise, time_steps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
@@ -834,19 +1123,34 @@ def main(args):
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
if accelerator.is_main_process:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
save_path = join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
if (
|
||||
args.validation_steps
|
||||
and any(args.validation_prompt)
|
||||
and global_step % args.validation_steps == 0
|
||||
):
|
||||
images_set = generate_validation_images(
|
||||
text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype
|
||||
)
|
||||
for images, validation_prompt in zip(images_set, args.validation_prompt):
|
||||
if len(images) > 0:
|
||||
label = str(uuid.uuid1())[:8] # generate an id for different set of images
|
||||
log_validation_images_to_tracker(
|
||||
images, label, validation_prompt, accelerator, global_step
|
||||
)
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
@@ -854,7 +1158,7 @@ def main(args):
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
||||
+3
-1
@@ -568,7 +568,9 @@ def main(args):
|
||||
|
||||
clean_images = batch["input"]
|
||||
# Sample noise that we'll add to the images
|
||||
noise = torch.randn(clean_images.shape).to(clean_images.device)
|
||||
noise = torch.randn(
|
||||
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
|
||||
).to(clean_images.device)
|
||||
bsz = clean_images.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.18.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -557,7 +557,9 @@ def main(args):
|
||||
|
||||
clean_images = batch["input"]
|
||||
# Sample noise that we'll add to the images
|
||||
noise = torch.randn(clean_images.shape).to(clean_images.device)
|
||||
noise = torch.randn(
|
||||
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
|
||||
).to(clean_images.device)
|
||||
bsz = clean_images.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
CMStochasticIterativeScheduler,
|
||||
ConsistencyModelPipeline,
|
||||
UNet2DModel,
|
||||
)
|
||||
|
||||
|
||||
TEST_UNET_CONFIG = {
|
||||
"sample_size": 32,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"layers_per_block": 2,
|
||||
"num_class_embeds": 1000,
|
||||
"block_out_channels": [32, 64],
|
||||
"attention_head_dim": 8,
|
||||
"down_block_types": [
|
||||
"ResnetDownsampleBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
],
|
||||
"up_block_types": [
|
||||
"AttnUpBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
],
|
||||
"resnet_time_scale_shift": "scale_shift",
|
||||
"upsample_type": "resnet",
|
||||
"downsample_type": "resnet",
|
||||
}
|
||||
|
||||
IMAGENET_64_UNET_CONFIG = {
|
||||
"sample_size": 64,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"layers_per_block": 3,
|
||||
"num_class_embeds": 1000,
|
||||
"block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4],
|
||||
"attention_head_dim": 64,
|
||||
"down_block_types": [
|
||||
"ResnetDownsampleBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
],
|
||||
"up_block_types": [
|
||||
"AttnUpBlock2D",
|
||||
"AttnUpBlock2D",
|
||||
"AttnUpBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
],
|
||||
"resnet_time_scale_shift": "scale_shift",
|
||||
"upsample_type": "resnet",
|
||||
"downsample_type": "resnet",
|
||||
}
|
||||
|
||||
LSUN_256_UNET_CONFIG = {
|
||||
"sample_size": 256,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"layers_per_block": 2,
|
||||
"num_class_embeds": None,
|
||||
"block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
|
||||
"attention_head_dim": 64,
|
||||
"down_block_types": [
|
||||
"ResnetDownsampleBlock2D",
|
||||
"ResnetDownsampleBlock2D",
|
||||
"ResnetDownsampleBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
],
|
||||
"up_block_types": [
|
||||
"AttnUpBlock2D",
|
||||
"AttnUpBlock2D",
|
||||
"AttnUpBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
],
|
||||
"resnet_time_scale_shift": "default",
|
||||
"upsample_type": "resnet",
|
||||
"downsample_type": "resnet",
|
||||
}
|
||||
|
||||
CD_SCHEDULER_CONFIG = {
|
||||
"num_train_timesteps": 40,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 80.0,
|
||||
}
|
||||
|
||||
CT_IMAGENET_64_SCHEDULER_CONFIG = {
|
||||
"num_train_timesteps": 201,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 80.0,
|
||||
}
|
||||
|
||||
CT_LSUN_256_SCHEDULER_CONFIG = {
|
||||
"num_train_timesteps": 151,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 80.0,
|
||||
}
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
"""
|
||||
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("boolean value expected")
|
||||
|
||||
|
||||
def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
|
||||
new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"]
|
||||
new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"]
|
||||
new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"]
|
||||
new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"]
|
||||
new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"]
|
||||
new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"]
|
||||
new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"]
|
||||
new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"]
|
||||
new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"]
|
||||
new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"]
|
||||
|
||||
if has_skip:
|
||||
new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"]
|
||||
new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"]
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
|
||||
weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
|
||||
bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)
|
||||
|
||||
new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
|
||||
new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]
|
||||
|
||||
new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1)
|
||||
new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1)
|
||||
new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1)
|
||||
new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1)
|
||||
new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1)
|
||||
new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1)
|
||||
|
||||
new_checkpoint[f"{new_prefix}.to_out.0.weight"] = (
|
||||
checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1)
|
||||
)
|
||||
new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1)
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def con_pt_to_diffuser(checkpoint_path: str, unet_config):
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
|
||||
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
|
||||
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
|
||||
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
|
||||
|
||||
if unet_config["num_class_embeds"] is not None:
|
||||
new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]
|
||||
|
||||
new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
|
||||
new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
|
||||
|
||||
down_block_types = unet_config["down_block_types"]
|
||||
layers_per_block = unet_config["layers_per_block"]
|
||||
attention_head_dim = unet_config["attention_head_dim"]
|
||||
channels_list = unet_config["block_out_channels"]
|
||||
current_layer = 1
|
||||
prev_channels = channels_list[0]
|
||||
|
||||
for i, layer_type in enumerate(down_block_types):
|
||||
current_channels = channels_list[i]
|
||||
downsample_block_has_skip = current_channels != prev_channels
|
||||
if layer_type == "ResnetDownsampleBlock2D":
|
||||
for j in range(layers_per_block):
|
||||
new_prefix = f"down_blocks.{i}.resnets.{j}"
|
||||
old_prefix = f"input_blocks.{current_layer}.0"
|
||||
has_skip = True if j == 0 and downsample_block_has_skip else False
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
|
||||
current_layer += 1
|
||||
|
||||
elif layer_type == "AttnDownBlock2D":
|
||||
for j in range(layers_per_block):
|
||||
new_prefix = f"down_blocks.{i}.resnets.{j}"
|
||||
old_prefix = f"input_blocks.{current_layer}.0"
|
||||
has_skip = True if j == 0 and downsample_block_has_skip else False
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
|
||||
new_prefix = f"down_blocks.{i}.attentions.{j}"
|
||||
old_prefix = f"input_blocks.{current_layer}.1"
|
||||
new_checkpoint = convert_attention(
|
||||
checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
|
||||
)
|
||||
current_layer += 1
|
||||
|
||||
if i != len(down_block_types) - 1:
|
||||
new_prefix = f"down_blocks.{i}.downsamplers.0"
|
||||
old_prefix = f"input_blocks.{current_layer}.0"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
current_layer += 1
|
||||
|
||||
prev_channels = current_channels
|
||||
|
||||
# hardcoded the mid-block for now
|
||||
new_prefix = "mid_block.resnets.0"
|
||||
old_prefix = "middle_block.0"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
new_prefix = "mid_block.attentions.0"
|
||||
old_prefix = "middle_block.1"
|
||||
new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
|
||||
new_prefix = "mid_block.resnets.1"
|
||||
old_prefix = "middle_block.2"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
|
||||
current_layer = 0
|
||||
up_block_types = unet_config["up_block_types"]
|
||||
|
||||
for i, layer_type in enumerate(up_block_types):
|
||||
if layer_type == "ResnetUpsampleBlock2D":
|
||||
for j in range(layers_per_block + 1):
|
||||
new_prefix = f"up_blocks.{i}.resnets.{j}"
|
||||
old_prefix = f"output_blocks.{current_layer}.0"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
|
||||
current_layer += 1
|
||||
|
||||
if i != len(up_block_types) - 1:
|
||||
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
||||
old_prefix = f"output_blocks.{current_layer-1}.1"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
elif layer_type == "AttnUpBlock2D":
|
||||
for j in range(layers_per_block + 1):
|
||||
new_prefix = f"up_blocks.{i}.resnets.{j}"
|
||||
old_prefix = f"output_blocks.{current_layer}.0"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
|
||||
new_prefix = f"up_blocks.{i}.attentions.{j}"
|
||||
old_prefix = f"output_blocks.{current_layer}.1"
|
||||
new_checkpoint = convert_attention(
|
||||
checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
|
||||
)
|
||||
current_layer += 1
|
||||
|
||||
if i != len(up_block_types) - 1:
|
||||
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
||||
old_prefix = f"output_blocks.{current_layer-1}.2"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
|
||||
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
|
||||
new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
|
||||
new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
|
||||
new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
|
||||
parser.add_argument(
|
||||
"--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
|
||||
)
|
||||
parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.")
|
||||
|
||||
args = parser.parse_args()
|
||||
args.class_cond = str2bool(args.class_cond)
|
||||
|
||||
ckpt_name = os.path.basename(args.unet_path)
|
||||
print(f"Checkpoint: {ckpt_name}")
|
||||
|
||||
# Get U-Net config
|
||||
if "imagenet64" in ckpt_name:
|
||||
unet_config = IMAGENET_64_UNET_CONFIG
|
||||
elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
|
||||
unet_config = LSUN_256_UNET_CONFIG
|
||||
elif "test" in ckpt_name:
|
||||
unet_config = TEST_UNET_CONFIG
|
||||
else:
|
||||
raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
|
||||
|
||||
if not args.class_cond:
|
||||
unet_config["num_class_embeds"] = None
|
||||
|
||||
converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)
|
||||
|
||||
image_unet = UNet2DModel(**unet_config)
|
||||
image_unet.load_state_dict(converted_unet_ckpt)
|
||||
|
||||
# Get scheduler config
|
||||
if "cd" in ckpt_name or "test" in ckpt_name:
|
||||
scheduler_config = CD_SCHEDULER_CONFIG
|
||||
elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
|
||||
scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
|
||||
elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
|
||||
scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
|
||||
else:
|
||||
raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
|
||||
|
||||
cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)
|
||||
|
||||
consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
|
||||
consistency_model.save_pretrained(args.dump_path)
|
||||
@@ -126,6 +126,13 @@ if __name__ == "__main__":
|
||||
"--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint."
|
||||
)
|
||||
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
||||
parser.add_argument(
|
||||
"--vae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Set to a path, hub id to an already converted vae to not convert it again.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
@@ -144,6 +151,7 @@ if __name__ == "__main__":
|
||||
stable_unclip_prior=args.stable_unclip_prior,
|
||||
clip_stats_path=args.clip_stats_path,
|
||||
controlnet=args.controlnet,
|
||||
vae_path=args.vae_path,
|
||||
)
|
||||
|
||||
if args.half:
|
||||
|
||||
@@ -0,0 +1,594 @@
|
||||
import argparse
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from accelerate import load_checkpoint_and_dispatch
|
||||
|
||||
from diffusers.models.prior_transformer import PriorTransformer
|
||||
from diffusers.pipelines.shap_e import ShapERenderer
|
||||
|
||||
|
||||
"""
|
||||
Example - From the diffusers root directory:
|
||||
|
||||
Download weights:
|
||||
```sh
|
||||
$ wget "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt"
|
||||
```
|
||||
|
||||
Convert the model:
|
||||
```sh
|
||||
$ python scripts/convert_shap_e_to_diffusers.py \
|
||||
--prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
|
||||
--prior_image_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/image_cond.pt \
|
||||
--transmitter_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt\
|
||||
--dump_path /home/yiyi_huggingface_co/model_repo/shap-e/renderer\
|
||||
--debug renderer
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# prior
|
||||
|
||||
PRIOR_ORIGINAL_PREFIX = "wrapped"
|
||||
|
||||
PRIOR_CONFIG = {
|
||||
"num_attention_heads": 16,
|
||||
"attention_head_dim": 1024 // 16,
|
||||
"num_layers": 24,
|
||||
"embedding_dim": 1024,
|
||||
"num_embeddings": 1024,
|
||||
"additional_embeddings": 0,
|
||||
"time_embed_act_fn": "gelu",
|
||||
"norm_in_type": "layer",
|
||||
"encoder_hid_proj_type": None,
|
||||
"added_emb_type": None,
|
||||
"time_embed_dim": 1024 * 4,
|
||||
"embedding_proj_dim": 768,
|
||||
"clip_embed_dim": 1024 * 2,
|
||||
}
|
||||
|
||||
|
||||
def prior_model_from_original_config():
|
||||
model = PriorTransformer(**PRIOR_CONFIG)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# <original>.time_embed.c_fc -> <diffusers>.time_embedding.linear_1
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.weight"],
|
||||
"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.time_embed.c_proj -> <diffusers>.time_embedding.linear_2
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.weight"],
|
||||
"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.input_proj -> <diffusers>.proj_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.weight"],
|
||||
"proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.clip_emb -> <diffusers>.embedding_proj
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.weight"],
|
||||
"embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.pos_emb -> <diffusers>.positional_embedding
|
||||
diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.pos_emb"][None, :]})
|
||||
|
||||
# <original>.ln_pre -> <diffusers>.norm_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"norm_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.weight"],
|
||||
"norm_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.backbone.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
|
||||
for idx in range(len(model.transformer_blocks)):
|
||||
diffusers_transformer_prefix = f"transformer_blocks.{idx}"
|
||||
original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.backbone.resblocks.{idx}"
|
||||
|
||||
# <original>.attn -> <diffusers>.attn1
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
|
||||
original_attention_prefix = f"{original_transformer_prefix}.attn"
|
||||
diffusers_checkpoint.update(
|
||||
prior_attention_to_diffusers(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
original_attention_prefix=original_attention_prefix,
|
||||
attention_head_dim=model.attention_head_dim,
|
||||
)
|
||||
)
|
||||
|
||||
# <original>.mlp -> <diffusers>.ff
|
||||
diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
|
||||
original_ff_prefix = f"{original_transformer_prefix}.mlp"
|
||||
diffusers_checkpoint.update(
|
||||
prior_ff_to_diffusers(
|
||||
checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# <original>.ln_1 -> <diffusers>.norm1
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
|
||||
f"{original_transformer_prefix}.ln_1.weight"
|
||||
],
|
||||
f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.ln_2 -> <diffusers>.norm3
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
|
||||
f"{original_transformer_prefix}.ln_2.weight"
|
||||
],
|
||||
f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.ln_post -> <diffusers>.norm_out
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.weight"],
|
||||
"norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.output_proj -> <diffusers>.proj_to_clip_embeddings
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.weight"],
|
||||
"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def prior_attention_to_diffusers(
|
||||
checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
|
||||
):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
|
||||
[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
|
||||
weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
|
||||
bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
|
||||
split=3,
|
||||
chunk_size=attention_head_dim,
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_attention_prefix}.to_q.weight": q_weight,
|
||||
f"{diffusers_attention_prefix}.to_q.bias": q_bias,
|
||||
f"{diffusers_attention_prefix}.to_k.weight": k_weight,
|
||||
f"{diffusers_attention_prefix}.to_k.bias": k_bias,
|
||||
f"{diffusers_attention_prefix}.to_v.weight": v_weight,
|
||||
f"{diffusers_attention_prefix}.to_v.bias": v_bias,
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.c_proj -> <diffusers>.to_out.0
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
|
||||
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
|
||||
diffusers_checkpoint = {
|
||||
# <original>.c_fc -> <diffusers>.net.0.proj
|
||||
f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
|
||||
f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
|
||||
# <original>.c_proj -> <diffusers>.net.2
|
||||
f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
|
||||
f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
|
||||
}
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
# done prior
|
||||
|
||||
|
||||
# prior_image (only slightly different from prior)
|
||||
|
||||
|
||||
PRIOR_IMAGE_ORIGINAL_PREFIX = "wrapped"
|
||||
|
||||
# Uses default arguments
|
||||
PRIOR_IMAGE_CONFIG = {
|
||||
"num_attention_heads": 8,
|
||||
"attention_head_dim": 1024 // 8,
|
||||
"num_layers": 24,
|
||||
"embedding_dim": 1024,
|
||||
"num_embeddings": 1024,
|
||||
"additional_embeddings": 0,
|
||||
"time_embed_act_fn": "gelu",
|
||||
"norm_in_type": "layer",
|
||||
"embedding_proj_norm_type": "layer",
|
||||
"encoder_hid_proj_type": None,
|
||||
"added_emb_type": None,
|
||||
"time_embed_dim": 1024 * 4,
|
||||
"embedding_proj_dim": 1024,
|
||||
"clip_embed_dim": 1024 * 2,
|
||||
}
|
||||
|
||||
|
||||
def prior_image_model_from_original_config():
|
||||
model = PriorTransformer(**PRIOR_IMAGE_CONFIG)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def prior_image_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# <original>.time_embed.c_fc -> <diffusers>.time_embedding.linear_1
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.weight"],
|
||||
"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.time_embed.c_proj -> <diffusers>.time_embedding.linear_2
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.weight"],
|
||||
"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.input_proj -> <diffusers>.proj_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"proj_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.weight"],
|
||||
"proj_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.clip_embed.0 -> <diffusers>.embedding_proj_norm
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"embedding_proj_norm.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.weight"],
|
||||
"embedding_proj_norm.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>..clip_embed.1 -> <diffusers>.embedding_proj
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"embedding_proj.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.weight"],
|
||||
"embedding_proj.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.pos_emb -> <diffusers>.positional_embedding
|
||||
diffusers_checkpoint.update(
|
||||
{"positional_embedding": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.pos_emb"][None, :]}
|
||||
)
|
||||
|
||||
# <original>.ln_pre -> <diffusers>.norm_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"norm_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.weight"],
|
||||
"norm_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.backbone.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
|
||||
for idx in range(len(model.transformer_blocks)):
|
||||
diffusers_transformer_prefix = f"transformer_blocks.{idx}"
|
||||
original_transformer_prefix = f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.backbone.resblocks.{idx}"
|
||||
|
||||
# <original>.attn -> <diffusers>.attn1
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
|
||||
original_attention_prefix = f"{original_transformer_prefix}.attn"
|
||||
diffusers_checkpoint.update(
|
||||
prior_attention_to_diffusers(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
original_attention_prefix=original_attention_prefix,
|
||||
attention_head_dim=model.attention_head_dim,
|
||||
)
|
||||
)
|
||||
|
||||
# <original>.mlp -> <diffusers>.ff
|
||||
diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
|
||||
original_ff_prefix = f"{original_transformer_prefix}.mlp"
|
||||
diffusers_checkpoint.update(
|
||||
prior_ff_to_diffusers(
|
||||
checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# <original>.ln_1 -> <diffusers>.norm1
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
|
||||
f"{original_transformer_prefix}.ln_1.weight"
|
||||
],
|
||||
f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.ln_2 -> <diffusers>.norm3
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
|
||||
f"{original_transformer_prefix}.ln_2.weight"
|
||||
],
|
||||
f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.ln_post -> <diffusers>.norm_out
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"norm_out.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.weight"],
|
||||
"norm_out.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.output_proj -> <diffusers>.proj_to_clip_embeddings
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.weight"],
|
||||
"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
# done prior_image
|
||||
|
||||
|
||||
# renderer
|
||||
|
||||
RENDERER_CONFIG = {}
|
||||
|
||||
|
||||
def renderer_model_from_original_config():
|
||||
model = ShapERenderer(**RENDERER_CONFIG)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
RENDERER_MLP_ORIGINAL_PREFIX = "renderer.nerstf"
|
||||
|
||||
RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX = "encoder.params_proj"
|
||||
|
||||
|
||||
def renderer_model_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
diffusers_checkpoint.update(
|
||||
{f"mlp.{k}": checkpoint[f"{RENDERER_MLP_ORIGINAL_PREFIX}.{k}"] for k in model.mlp.state_dict().keys()}
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"params_proj.{k}": checkpoint[f"{RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX}.{k}"]
|
||||
for k in model.params_proj.state_dict().keys()
|
||||
}
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update({"void.background": torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)})
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
# done renderer
|
||||
|
||||
|
||||
# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
|
||||
def split_attentions(*, weight, bias, split, chunk_size):
|
||||
weights = [None] * split
|
||||
biases = [None] * split
|
||||
|
||||
weights_biases_idx = 0
|
||||
|
||||
for starting_row_index in range(0, weight.shape[0], chunk_size):
|
||||
row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
|
||||
|
||||
weight_rows = weight[row_indices, :]
|
||||
bias_rows = bias[row_indices]
|
||||
|
||||
if weights[weights_biases_idx] is None:
|
||||
assert weights[weights_biases_idx] is None
|
||||
weights[weights_biases_idx] = weight_rows
|
||||
biases[weights_biases_idx] = bias_rows
|
||||
else:
|
||||
assert weights[weights_biases_idx] is not None
|
||||
weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
|
||||
biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
|
||||
|
||||
weights_biases_idx = (weights_biases_idx + 1) % split
|
||||
|
||||
return weights, biases
|
||||
|
||||
|
||||
# done unet utils
|
||||
|
||||
|
||||
# Driver functions
|
||||
|
||||
|
||||
def prior(*, args, checkpoint_map_location):
|
||||
print("loading prior")
|
||||
|
||||
prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
prior_model = prior_model_from_original_config()
|
||||
|
||||
prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint)
|
||||
|
||||
del prior_checkpoint
|
||||
|
||||
load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model)
|
||||
|
||||
print("done loading prior")
|
||||
|
||||
return prior_model
|
||||
|
||||
|
||||
def prior_image(*, args, checkpoint_map_location):
|
||||
print("loading prior_image")
|
||||
|
||||
print(f"load checkpoint from {args.prior_image_checkpoint_path}")
|
||||
prior_checkpoint = torch.load(args.prior_image_checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
prior_model = prior_image_model_from_original_config()
|
||||
|
||||
prior_diffusers_checkpoint = prior_image_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint)
|
||||
|
||||
del prior_checkpoint
|
||||
|
||||
load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model)
|
||||
|
||||
print("done loading prior_image")
|
||||
|
||||
return prior_model
|
||||
|
||||
|
||||
def renderer(*, args, checkpoint_map_location):
|
||||
print(" loading renderer")
|
||||
|
||||
renderer_checkpoint = torch.load(args.transmitter_checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
renderer_model = renderer_model_from_original_config()
|
||||
|
||||
renderer_diffusers_checkpoint = renderer_model_original_checkpoint_to_diffusers_checkpoint(
|
||||
renderer_model, renderer_checkpoint
|
||||
)
|
||||
|
||||
del renderer_checkpoint
|
||||
|
||||
load_checkpoint_to_model(renderer_diffusers_checkpoint, renderer_model, strict=True)
|
||||
|
||||
print("done loading renderer")
|
||||
|
||||
return renderer_model
|
||||
|
||||
|
||||
# prior model will expect clip_mean and clip_std, whic are missing from the state_dict
|
||||
PRIOR_EXPECTED_MISSING_KEYS = ["clip_mean", "clip_std"]
|
||||
|
||||
|
||||
def load_prior_checkpoint_to_model(checkpoint, model):
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
torch.save(checkpoint, file.name)
|
||||
del checkpoint
|
||||
missing_keys, unexpected_keys = model.load_state_dict(torch.load(file.name), strict=False)
|
||||
missing_keys = list(set(missing_keys) - set(PRIOR_EXPECTED_MISSING_KEYS))
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
raise ValueError(f"Unexpected keys when loading prior model: {unexpected_keys}")
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(f"Missing keys when loading prior model: {missing_keys}")
|
||||
|
||||
|
||||
def load_checkpoint_to_model(checkpoint, model, strict=False):
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
torch.save(checkpoint, file.name)
|
||||
del checkpoint
|
||||
if strict:
|
||||
model.load_state_dict(torch.load(file.name), strict=True)
|
||||
else:
|
||||
load_checkpoint_and_dispatch(model, file.name, device_map="auto")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
|
||||
parser.add_argument(
|
||||
"--prior_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the prior checkpoint to convert.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prior_image_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the prior_image checkpoint to convert.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--transmitter_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the transmitter checkpoint to convert.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_load_device",
|
||||
default="cpu",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The device passed to `map_location` when loading checkpoints.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Only run a specific stage of the convert script. Used for debugging",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"loading checkpoints to {args.checkpoint_load_device}")
|
||||
|
||||
checkpoint_map_location = torch.device(args.checkpoint_load_device)
|
||||
|
||||
if args.debug is not None:
|
||||
print(f"debug: only executing {args.debug}")
|
||||
|
||||
if args.debug is None:
|
||||
print("YiYi TO-DO")
|
||||
elif args.debug == "prior":
|
||||
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
|
||||
prior_model.save_pretrained(args.dump_path)
|
||||
elif args.debug == "prior_image":
|
||||
prior_model = prior_image(args=args, checkpoint_map_location=checkpoint_map_location)
|
||||
prior_model.save_pretrained(args.dump_path)
|
||||
elif args.debug == "renderer":
|
||||
renderer_model = renderer(args=args, checkpoint_map_location=checkpoint_map_location)
|
||||
renderer_model.save_pretrained(args.dump_path)
|
||||
else:
|
||||
raise ValueError(f"unknown debug value : {args.debug}")
|
||||
@@ -89,6 +89,7 @@ _deps = [
|
||||
"huggingface-hub>=0.13.2",
|
||||
"requests-mock==1.10.0",
|
||||
"importlib_metadata",
|
||||
"invisible-watermark",
|
||||
"isort>=5.5.4",
|
||||
"jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib>=0.1.65",
|
||||
@@ -193,6 +194,7 @@ extras["test"] = deps_list(
|
||||
"compel",
|
||||
"datasets",
|
||||
"Jinja2",
|
||||
"invisible-watermark",
|
||||
"k-diffusion",
|
||||
"librosa",
|
||||
"omegaconf",
|
||||
@@ -230,7 +232,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.18.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.18.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
__version__ = "0.18.0.dev0"
|
||||
__version__ = "0.18.2"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_flax_available,
|
||||
is_inflect_available,
|
||||
is_invisible_watermark_available,
|
||||
is_k_diffusion_available,
|
||||
is_k_diffusion_version,
|
||||
is_librosa_available,
|
||||
@@ -58,6 +59,7 @@ else:
|
||||
)
|
||||
from .pipelines import (
|
||||
AudioPipelineOutput,
|
||||
ConsistencyModelPipeline,
|
||||
DanceDiffusionPipeline,
|
||||
DDIMPipeline,
|
||||
DDPMPipeline,
|
||||
@@ -72,6 +74,7 @@ else:
|
||||
ScoreSdeVePipeline,
|
||||
)
|
||||
from .schedulers import (
|
||||
CMStochasticIterativeScheduler,
|
||||
DDIMInverseScheduler,
|
||||
DDIMParallelScheduler,
|
||||
DDIMScheduler,
|
||||
@@ -136,9 +139,18 @@ else:
|
||||
KandinskyInpaintPipeline,
|
||||
KandinskyPipeline,
|
||||
KandinskyPriorPipeline,
|
||||
KandinskyV22ControlnetImg2ImgPipeline,
|
||||
KandinskyV22ControlnetPipeline,
|
||||
KandinskyV22Img2ImgPipeline,
|
||||
KandinskyV22InpaintPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
PaintByExamplePipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
@@ -177,6 +189,14 @@ else:
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -423,6 +423,10 @@ class ConfigMixin:
|
||||
|
||||
@classmethod
|
||||
def extract_init_dict(cls, config_dict, **kwargs):
|
||||
# Skip keys that were not present in the original config, so default __init__ values were used
|
||||
used_defaults = config_dict.get("_use_default_values", [])
|
||||
config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
|
||||
|
||||
# 0. Copy origin config dict
|
||||
original_dict = dict(config_dict.items())
|
||||
|
||||
@@ -544,8 +548,9 @@ class ConfigMixin:
|
||||
return value
|
||||
|
||||
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
||||
# Don't save "_ignore_files"
|
||||
# Don't save "_ignore_files" or "_use_default_values"
|
||||
config_dict.pop("_ignore_files", None)
|
||||
config_dict.pop("_use_default_values", None)
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
@@ -599,6 +604,11 @@ def register_to_config(init):
|
||||
if k not in ignore and k not in new_kwargs
|
||||
}
|
||||
)
|
||||
|
||||
# Take note of the parameters that were not present in the loaded config
|
||||
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
|
||||
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
|
||||
|
||||
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
||||
getattr(self, "register_to_config")(**new_kwargs)
|
||||
init(self, *args, **init_kwargs)
|
||||
@@ -643,6 +653,10 @@ def flax_register_to_config(cls):
|
||||
name = fields[i].name
|
||||
new_kwargs[name] = arg
|
||||
|
||||
# Take note of the parameters that were not present in the loaded config
|
||||
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
|
||||
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
|
||||
|
||||
getattr(self, "register_to_config")(**new_kwargs)
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ deps = {
|
||||
"huggingface-hub": "huggingface-hub>=0.13.2",
|
||||
"requests-mock": "requests-mock==1.10.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"invisible-watermark": "invisible-watermark",
|
||||
"isort": "isort>=5.5.4",
|
||||
"jax": "jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib": "jaxlib>=0.1.65",
|
||||
|
||||
@@ -312,12 +312,17 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
if images.shape[-1] == 1:
|
||||
# special case for grayscale (single channel) images
|
||||
raise Exception("Not supported")
|
||||
images_depth = images[:, :, :, 3:]
|
||||
if images.shape[-1] == 6:
|
||||
images_depth = (images_depth * 255).round().astype("uint8")
|
||||
pil_images = [
|
||||
Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
|
||||
]
|
||||
elif images.shape[-1] == 4:
|
||||
images_depth = (images_depth * 65535.0).astype(np.uint16)
|
||||
pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
|
||||
else:
|
||||
pil_images = [Image.fromarray(self.rgblike_to_depthmap(image[:, :, 3:]), mode="I;16") for image in images]
|
||||
raise Exception("Not supported")
|
||||
|
||||
return pil_images
|
||||
|
||||
@@ -349,7 +354,11 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
image = self.pt_to_numpy(image)
|
||||
|
||||
if output_type == "np":
|
||||
return image[:, :, :, :3], np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
|
||||
if image.shape[-1] == 6:
|
||||
image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
|
||||
else:
|
||||
image_depth = image[:, :, :, 3:]
|
||||
return image[:, :, :, :3], image_depth
|
||||
|
||||
if output_type == "pil":
|
||||
return self.numpy_to_pil(image), self.numpy_to_depth(image)
|
||||
|
||||
+42
-62
@@ -12,8 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import io
|
||||
import requests
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@@ -179,7 +177,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
@@ -591,7 +589,7 @@ class TextualInversionLoaderMixin:
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
@@ -808,7 +806,7 @@ class LoraLoaderMixin:
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
@@ -1056,7 +1054,7 @@ class LoraLoaderMixin:
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
@@ -1278,13 +1276,19 @@ class LoraLoaderMixin:
|
||||
return new_state_dict, network_alpha
|
||||
|
||||
|
||||
class FromCkptMixin:
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
|
||||
def from_ckpt(cls, *args, **kwargs):
|
||||
deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
|
||||
deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
|
||||
return cls.from_single_file(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
|
||||
is set in evaluation mode (`model.eval()`) by default.
|
||||
@@ -1363,16 +1367,16 @@ class FromCkptMixin:
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> # Download pipeline from huggingface.co and cache.
|
||||
>>> pipeline = StableDiffusionPipeline.from_ckpt(
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||
... )
|
||||
|
||||
>>> # Download pipeline from local file
|
||||
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
||||
>>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly")
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
|
||||
|
||||
>>> # Enable float16 and move to GPU
|
||||
>>> pipeline = StableDiffusionPipeline.from_ckpt(
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
||||
... torch_dtype=torch.float16,
|
||||
... )
|
||||
@@ -1390,7 +1394,7 @@ class FromCkptMixin:
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", 512)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
@@ -1434,62 +1438,38 @@ class FromCkptMixin:
|
||||
else:
|
||||
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
if Path(pretrained_model_link_or_path).is_file():
|
||||
pretrained_model_path_or_dict = pretrained_model_link_or_path
|
||||
elif not Path(pretrained_model_link_or_path).is_file():
|
||||
is_hf = False
|
||||
is_civit_ai = False
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
is_hf = True
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
for prefix in ["https://civitai.com/", "civitai.com"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
if "api" not in pretrained_model_link_or_path:
|
||||
raise ValueError(f"{pretrained_model_link_or_path} is not a valid Civitai link. Make sure to provide a link in the form: https://civitai.com/api/models/<num>")
|
||||
is_civit_ai = True
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if is_hf:
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(pretrained_model_link_or_path.parts[:2])
|
||||
file_path = "/".join(pretrained_model_link_or_path.parts[2:])
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_path_or_dict = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
pretrained_model_path_or_dict = pretrained_model_link_or_path
|
||||
elif is_civit_ai:
|
||||
response = requests.get(pretrained_model_link_or_path)
|
||||
checkpoint_bytes = response.content
|
||||
|
||||
# Create an in-memory byte stream using io.BytesIO()
|
||||
buffer = io.BytesIO(checkpoint_bytes)
|
||||
|
||||
try:
|
||||
pretrained_model_path_or_dict = safetensors.torch.load(buffer)
|
||||
except IOError as e:
|
||||
pass
|
||||
|
||||
pretrained_model_path_or_dict = torch.load(buffer)
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_path_or_dict,
|
||||
pretrained_model_link_or_path,
|
||||
pipeline_class=cls,
|
||||
model_type=model_type,
|
||||
stable_unclip=stable_unclip,
|
||||
|
||||
@@ -119,6 +119,15 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
@@ -141,6 +150,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
@@ -171,7 +181,20 @@ class BasicTransformerBlock(nn.Module):
|
||||
if self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
||||
raise ValueError(
|
||||
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||
)
|
||||
|
||||
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
||||
ff_output = torch.cat(
|
||||
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
||||
dim=self._chunk_dim,
|
||||
)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
@@ -152,6 +152,7 @@ class FlaxAttention(nn.Module):
|
||||
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
|
||||
|
||||
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
|
||||
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
@@ -214,7 +215,7 @@ class FlaxAttention(nn.Module):
|
||||
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
return hidden_states
|
||||
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
|
||||
|
||||
class FlaxBasicTransformerBlock(nn.Module):
|
||||
@@ -260,6 +261,7 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
||||
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
||||
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
||||
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
||||
|
||||
def __call__(self, hidden_states, context, deterministic=True):
|
||||
# self attention
|
||||
@@ -280,7 +282,7 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
|
||||
|
||||
class FlaxTransformer2DModel(nn.Module):
|
||||
@@ -356,6 +358,8 @@ class FlaxTransformer2DModel(nn.Module):
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
||||
|
||||
def __call__(self, hidden_states, context, deterministic=True):
|
||||
batch, height, width, channels = hidden_states.shape
|
||||
residual = hidden_states
|
||||
@@ -378,7 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
|
||||
|
||||
class FlaxFeedForward(nn.Module):
|
||||
@@ -409,7 +413,7 @@ class FlaxFeedForward(nn.Module):
|
||||
self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
hidden_states = self.net_0(hidden_states)
|
||||
hidden_states = self.net_0(hidden_states, deterministic=deterministic)
|
||||
hidden_states = self.net_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -434,8 +438,9 @@ class FlaxGEGLU(nn.Module):
|
||||
def setup(self):
|
||||
inner_dim = self.dim * 4
|
||||
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
|
||||
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
|
||||
return hidden_linear * nn.gelu(hidden_gelu)
|
||||
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
|
||||
|
||||
@@ -1118,7 +1118,9 @@ class AttnProcessor2_0:
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
|
||||
@@ -376,6 +376,29 @@ class TextImageProjection(nn.Module):
|
||||
return torch.cat([image_text_embeds, text_embeds], dim=1)
|
||||
|
||||
|
||||
class ImageProjection(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_embed_dim: int = 768,
|
||||
cross_attention_dim: int = 768,
|
||||
num_image_text_embeds: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_image_text_embeds = num_image_text_embeds
|
||||
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
||||
self.norm = nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, image_embeds: torch.FloatTensor):
|
||||
batch_size = image_embeds.shape[0]
|
||||
|
||||
# image
|
||||
image_embeds = self.image_embeds(image_embeds)
|
||||
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
||||
image_embeds = self.norm(image_embeds)
|
||||
return image_embeds
|
||||
|
||||
|
||||
class CombinedTimestepLabelEmbeddings(nn.Module):
|
||||
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
||||
super().__init__()
|
||||
@@ -429,6 +452,50 @@ class TextImageTimeEmbedding(nn.Module):
|
||||
return time_image_embeds + time_text_embeds
|
||||
|
||||
|
||||
class ImageTimeEmbedding(nn.Module):
|
||||
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
||||
super().__init__()
|
||||
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
||||
self.image_norm = nn.LayerNorm(time_embed_dim)
|
||||
|
||||
def forward(self, image_embeds: torch.FloatTensor):
|
||||
# image
|
||||
time_image_embeds = self.image_proj(image_embeds)
|
||||
time_image_embeds = self.image_norm(time_image_embeds)
|
||||
return time_image_embeds
|
||||
|
||||
|
||||
class ImageHintTimeEmbedding(nn.Module):
|
||||
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
||||
super().__init__()
|
||||
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
||||
self.image_norm = nn.LayerNorm(time_embed_dim)
|
||||
self.input_hint_block = nn.Sequential(
|
||||
nn.Conv2d(3, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(96, 96, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(256, 4, 3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
|
||||
# image
|
||||
time_image_embeds = self.image_proj(image_embeds)
|
||||
time_image_embeds = self.image_norm(time_image_embeds)
|
||||
hint = self.input_hint_block(hint)
|
||||
return time_image_embeds, hint
|
||||
|
||||
|
||||
class AttentionPooling(nn.Module):
|
||||
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
|
||||
|
||||
|
||||
@@ -456,7 +456,7 @@ class ModelMixin(torch.nn.Module):
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
|
||||
@@ -34,14 +34,33 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
||||
embedding_dim (`int`, *optional*, defaults to 768):
|
||||
The dimension of the CLIP embeddings. Image embeddings and text embeddings are both the same dimension.
|
||||
num_embeddings (`int`, *optional*, defaults to 77): The max number of CLIP embeddings allowed (the
|
||||
length of the prompt after it has been tokenized).
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
|
||||
num_embeddings (`int`, *optional*, defaults to 77):
|
||||
The number of embeddings of the model input `hidden_states`
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
|
||||
The activation function to use to create timestep embeddings.
|
||||
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
|
||||
passing to Transformer blocks. Set it to `None` if normalization is not needed.
|
||||
embedding_proj_norm_type (`str`, *optional*, defaults to None):
|
||||
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
|
||||
needed.
|
||||
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
|
||||
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
|
||||
`encoder_hidden_states` is `None`.
|
||||
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
|
||||
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
|
||||
product between the text embedding and image embedding as proposed in the unclip paper
|
||||
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
|
||||
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
|
||||
If None, will be set to `num_attention_heads * attention_head_dim`
|
||||
embedding_proj_dim (`int`, *optional*, default to None):
|
||||
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
|
||||
clip_embed_dim (`int`, *optional*, default to None):
|
||||
The dimension of the output. If None, will be set to `embedding_dim`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -54,6 +73,14 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
num_embeddings=77,
|
||||
additional_embeddings=4,
|
||||
dropout: float = 0.0,
|
||||
time_embed_act_fn: str = "silu",
|
||||
norm_in_type: Optional[str] = None, # layer
|
||||
embedding_proj_norm_type: Optional[str] = None, # layer
|
||||
encoder_hid_proj_type: Optional[str] = "linear", # linear
|
||||
added_emb_type: Optional[str] = "prd", # prd
|
||||
time_embed_dim: Optional[int] = None,
|
||||
embedding_proj_dim: Optional[int] = None,
|
||||
clip_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -61,17 +88,41 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
time_embed_dim = time_embed_dim or inner_dim
|
||||
embedding_proj_dim = embedding_proj_dim or embedding_dim
|
||||
clip_embed_dim = clip_embed_dim or embedding_dim
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, True, 0)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
if embedding_proj_norm_type is None:
|
||||
self.embedding_proj_norm = None
|
||||
elif embedding_proj_norm_type == "layer":
|
||||
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
|
||||
|
||||
if encoder_hid_proj_type is None:
|
||||
self.encoder_hidden_states_proj = None
|
||||
elif encoder_hid_proj_type == "linear":
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
|
||||
|
||||
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
||||
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
if added_emb_type == "prd":
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
elif added_emb_type is None:
|
||||
self.prd_embedding = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
@@ -87,8 +138,16 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
]
|
||||
)
|
||||
|
||||
if norm_in_type == "layer":
|
||||
self.norm_in = nn.LayerNorm(inner_dim)
|
||||
elif norm_in_type is None:
|
||||
self.norm_in = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
|
||||
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
|
||||
|
||||
causal_attention_mask = torch.full(
|
||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
||||
@@ -97,8 +156,8 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
causal_attention_mask = causal_attention_mask[None, ...]
|
||||
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
||||
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
@@ -172,7 +231,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
proj_embedding: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
@@ -217,23 +276,61 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
||||
time_embeddings = self.time_embedding(timesteps_projected)
|
||||
|
||||
if self.embedding_proj_norm is not None:
|
||||
proj_embedding = self.embedding_proj_norm(proj_embedding)
|
||||
|
||||
proj_embeddings = self.embedding_proj(proj_embedding)
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
|
||||
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
|
||||
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
||||
|
||||
additional_embeds = []
|
||||
additional_embeddings_len = 0
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
additional_embeds.append(encoder_hidden_states)
|
||||
additional_embeddings_len += encoder_hidden_states.shape[1]
|
||||
|
||||
if len(proj_embeddings.shape) == 2:
|
||||
proj_embeddings = proj_embeddings[:, None, :]
|
||||
|
||||
if len(hidden_states.shape) == 2:
|
||||
hidden_states = hidden_states[:, None, :]
|
||||
|
||||
additional_embeds = additional_embeds + [
|
||||
proj_embeddings,
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states,
|
||||
]
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
additional_embeds.append(prd_embedding)
|
||||
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
encoder_hidden_states,
|
||||
proj_embeddings[:, None, :],
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states[:, None, :],
|
||||
prd_embedding,
|
||||
],
|
||||
additional_embeds,
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
|
||||
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
|
||||
if positional_embeddings.shape[1] < hidden_states.shape[1]:
|
||||
positional_embeddings = F.pad(
|
||||
positional_embeddings,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
additional_embeddings_len,
|
||||
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
|
||||
),
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + positional_embeddings
|
||||
|
||||
if attention_mask is not None:
|
||||
@@ -242,11 +339,19 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
if self.norm_in is not None:
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = hidden_states[:, -1]
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
else:
|
||||
hidden_states = hidden_states[:, additional_embeddings_len:]
|
||||
|
||||
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -66,6 +66,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
||||
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
||||
downsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The downsample type for downsampling layers. Choose between "conv" and "resnet"
|
||||
upsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The upsample type for upsampling layers. Choose between "conv" and "resnet"
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
||||
@@ -96,6 +100,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
downsample_type: str = "conv",
|
||||
upsample_type: str = "conv",
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: Optional[int] = 8,
|
||||
norm_num_groups: int = 32,
|
||||
@@ -168,6 +174,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
downsample_type=downsample_type,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -207,6 +214,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
upsample_type=upsample_type,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
@@ -38,6 +38,7 @@ def get_down_block(
|
||||
add_downsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
transformer_layers_per_block=1,
|
||||
num_attention_heads=None,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
@@ -51,6 +52,7 @@ def get_down_block(
|
||||
resnet_out_scale_factor=1.0,
|
||||
cross_attention_norm=None,
|
||||
attention_head_dim=None,
|
||||
downsample_type=None,
|
||||
):
|
||||
# If attn head dim is not defined, we default it to the number of heads
|
||||
if attention_head_dim is None:
|
||||
@@ -88,24 +90,29 @@ def get_down_block(
|
||||
output_scale_factor=resnet_out_scale_factor,
|
||||
)
|
||||
elif down_block_type == "AttnDownBlock2D":
|
||||
if add_downsample is False:
|
||||
downsample_type = None
|
||||
else:
|
||||
downsample_type = downsample_type or "conv" # default to 'conv'
|
||||
return AttnDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
downsample_type=downsample_type,
|
||||
)
|
||||
elif down_block_type == "CrossAttnDownBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
|
||||
return CrossAttnDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -227,6 +234,7 @@ def get_up_block(
|
||||
add_upsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
transformer_layers_per_block=1,
|
||||
num_attention_heads=None,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
@@ -239,6 +247,7 @@ def get_up_block(
|
||||
resnet_out_scale_factor=1.0,
|
||||
cross_attention_norm=None,
|
||||
attention_head_dim=None,
|
||||
upsample_type=None,
|
||||
):
|
||||
# If attn head dim is not defined, we default it to the number of heads
|
||||
if attention_head_dim is None:
|
||||
@@ -281,6 +290,7 @@ def get_up_block(
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
|
||||
return CrossAttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
@@ -319,18 +329,23 @@ def get_up_block(
|
||||
cross_attention_norm=cross_attention_norm,
|
||||
)
|
||||
elif up_block_type == "AttnUpBlock2D":
|
||||
if add_upsample is False:
|
||||
upsample_type = None
|
||||
else:
|
||||
upsample_type = upsample_type or "conv" # default to 'conv'
|
||||
|
||||
return AttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
upsample_type=upsample_type,
|
||||
)
|
||||
elif up_block_type == "SkipUpBlock2D":
|
||||
return SkipUpBlock2D(
|
||||
@@ -506,6 +521,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -548,7 +564,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
num_layers=transformer_layers_per_block,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -747,11 +763,12 @@ class AttnDownBlock2D(nn.Module):
|
||||
attention_head_dim=1,
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
downsample_type="conv",
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
attentions = []
|
||||
self.downsample_type = downsample_type
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
@@ -793,7 +810,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_downsample:
|
||||
if downsample_type == "conv":
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
@@ -801,6 +818,24 @@ class AttnDownBlock2D(nn.Module):
|
||||
)
|
||||
]
|
||||
)
|
||||
elif downsample_type == "resnet":
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
down=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
@@ -810,11 +845,14 @@ class AttnDownBlock2D(nn.Module):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states)
|
||||
output_states += (hidden_states,)
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
if self.downsample_type == "resnet":
|
||||
hidden_states = downsampler(hidden_states, temb=temb)
|
||||
else:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
@@ -829,6 +867,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -873,7 +912,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
num_layers=transformer_layers_per_block,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -1860,12 +1899,14 @@ class AttnUpBlock2D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
attention_head_dim=1,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
upsample_type="conv",
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.upsample_type = upsample_type
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
|
||||
@@ -1908,8 +1949,26 @@ class AttnUpBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
if upsample_type == "conv":
|
||||
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
||||
elif upsample_type == "resnet":
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
up=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
@@ -1925,7 +1984,10 @@ class AttnUpBlock2D(nn.Module):
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
if self.upsample_type == "resnet":
|
||||
hidden_states = upsampler(hidden_states, temb=temb)
|
||||
else:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1939,6 +2001,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -1984,7 +2047,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
num_layers=transformer_layers_per_block,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
|
||||
@@ -25,6 +25,9 @@ from .activations import get_activation
|
||||
from .attention_processor import AttentionProcessor, AttnProcessor
|
||||
from .embeddings import (
|
||||
GaussianFourierProjection,
|
||||
ImageHintTimeEmbedding,
|
||||
ImageProjection,
|
||||
ImageTimeEmbedding,
|
||||
TextImageProjection,
|
||||
TextImageTimeEmbedding,
|
||||
TextTimeEmbedding,
|
||||
@@ -57,7 +60,7 @@ class UNet2DConditionOutput(BaseOutput):
|
||||
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
sample: torch.FloatTensor = None
|
||||
|
||||
|
||||
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
@@ -98,7 +101,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
encoder_hid_dim (`int`, *optional*, defaults to `None`):
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
@@ -115,6 +122,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
||||
Dimension for the timestep embeddings.
|
||||
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
@@ -170,6 +179,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
@@ -178,6 +188,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
@@ -200,6 +211,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
if num_attention_heads is not None:
|
||||
raise ValueError(
|
||||
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
||||
)
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
@@ -298,7 +314,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type == "image_proj":
|
||||
# Kandinsky 2.2
|
||||
self.encoder_hid_proj = ImageProjection(
|
||||
image_embed_dim=encoder_hid_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
@@ -351,6 +372,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
elif addition_embed_type == "image":
|
||||
# Kandinsky 2.2
|
||||
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
||||
elif addition_embed_type == "image_hint":
|
||||
# Kandinsky 2.2 ControlNet
|
||||
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
|
||||
@@ -383,6 +413,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if isinstance(layers_per_block, int):
|
||||
layers_per_block = [layers_per_block] * len(down_block_types)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
if class_embeddings_concat:
|
||||
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
||||
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
||||
@@ -401,6 +434,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block[i],
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
@@ -426,6 +460,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
# mid
|
||||
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
@@ -467,6 +502,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_layers_per_block = list(reversed(layers_per_block))
|
||||
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
||||
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
@@ -487,6 +523,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=reversed_layers_per_block[i] + 1,
|
||||
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
@@ -693,6 +730,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
@@ -763,6 +803,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
@@ -784,9 +825,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
emb = emb + aug_emb
|
||||
elif self.config.addition_embed_type == "text_image":
|
||||
# Kadinsky 2.1 - style
|
||||
# Kandinsky 2.1 - style
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
||||
@@ -794,9 +834,44 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
image_embs = added_cond_kwargs.get("image_embeds")
|
||||
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
||||
|
||||
aug_emb = self.add_embedding(text_embs, image_embs)
|
||||
emb = emb + aug_emb
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
elif self.config.addition_embed_type == "image":
|
||||
# Kandinsky 2.2 - style
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
image_embs = added_cond_kwargs.get("image_embeds")
|
||||
aug_emb = self.add_embedding(image_embs)
|
||||
elif self.config.addition_embed_type == "image_hint":
|
||||
# Kandinsky 2.2 - style
|
||||
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
image_embs = added_cond_kwargs.get("image_embeds")
|
||||
hint = added_cond_kwargs.get("hint")
|
||||
aug_emb, hint = self.add_embedding(image_embs, hint)
|
||||
sample = torch.cat([sample, hint], dim=1)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
if self.time_embed_act is not None:
|
||||
emb = self.time_embed_act(emb)
|
||||
@@ -812,7 +887,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
image_embeds = added_cond_kwargs.get("image_embeds")
|
||||
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
||||
|
||||
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
||||
# Kandinsky 2.2 - style
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
||||
)
|
||||
image_embeds = added_cond_kwargs.get("image_embeds")
|
||||
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
|
||||
@@ -133,6 +133,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
block_out_channels = self.block_out_channels
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
if self.num_attention_heads is not None:
|
||||
raise ValueError(
|
||||
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
||||
)
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
|
||||
@@ -250,10 +250,11 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = temp_attn(
|
||||
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
||||
).sample
|
||||
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
|
||||
)[0]
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
||||
|
||||
@@ -377,10 +378,11 @@ class CrossAttnDownBlock3D(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = temp_attn(
|
||||
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
||||
).sample
|
||||
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
|
||||
)[0]
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
@@ -590,10 +592,11 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = temp_attn(
|
||||
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
||||
).sample
|
||||
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
|
||||
)[0]
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
|
||||
@@ -114,6 +114,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
if num_attention_heads is not None:
|
||||
raise NotImplementedError(
|
||||
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
||||
)
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
@@ -389,6 +394,46 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def enable_forward_chunking(self, chunk_size=None, dim=0):
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
||||
|
||||
Parameters:
|
||||
chunk_size (`int`, *optional*):
|
||||
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
||||
over each tensor of dim=`dim`.
|
||||
dim (`int`, *optional*, defaults to `0`):
|
||||
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
||||
or dim=1 (sequence length).
|
||||
"""
|
||||
if dim not in [0, 1]:
|
||||
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
||||
|
||||
# By default chunk size is 1
|
||||
chunk_size = chunk_size or 1
|
||||
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
def disable_forward_chunking(self):
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
@@ -486,8 +531,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
sample = self.transformer_in(
|
||||
sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
||||
).sample
|
||||
sample,
|
||||
num_frames=num_frames,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils import BaseOutput, apply_forward_hook
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
|
||||
|
||||
@@ -116,6 +116,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
norm_type=norm_type,
|
||||
)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
@@ -125,6 +126,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
|
||||
return VQEncoderOutput(latents=h)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from ..utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_flax_available,
|
||||
is_invisible_watermark_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
@@ -16,6 +17,7 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .consistency_models import ConsistencyModelPipeline
|
||||
from .dance_diffusion import DanceDiffusionPipeline
|
||||
from .ddim import DDIMPipeline
|
||||
from .ddpm import DDPMPipeline
|
||||
@@ -63,9 +65,19 @@ else:
|
||||
KandinskyPipeline,
|
||||
KandinskyPriorPipeline,
|
||||
)
|
||||
from .kandinsky2_2 import (
|
||||
KandinskyV22ControlnetImg2ImgPipeline,
|
||||
KandinskyV22ControlnetPipeline,
|
||||
KandinskyV22Img2ImgPipeline,
|
||||
KandinskyV22InpaintPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_diffusion import (
|
||||
CycleDiffusionPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
@@ -100,6 +112,15 @@ else:
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
||||
else:
|
||||
from .stable_diffusion_xl import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -77,7 +77,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
@@ -26,7 +26,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
|
||||
@@ -95,7 +95,9 @@ def preprocess(image):
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
|
||||
class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class AltDiffusionImg2ImgPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-guided image to image generation using Alt Diffusion.
|
||||
|
||||
@@ -105,7 +107,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .pipeline_consistency_models import ConsistencyModelPipeline
|
||||
@@ -0,0 +1,337 @@
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...schedulers import CMStochasticIterativeScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
|
||||
>>> from diffusers import ConsistencyModelPipeline
|
||||
|
||||
>>> device = "cuda"
|
||||
>>> # Load the cd_imagenet64_l2 checkpoint.
|
||||
>>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
|
||||
>>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
|
||||
>>> pipe.to(device)
|
||||
|
||||
>>> # Onestep Sampling
|
||||
>>> image = pipe(num_inference_steps=1).images[0]
|
||||
>>> image.save("cd_imagenet64_l2_onestep_sample.png")
|
||||
|
||||
>>> # Onestep sampling, class-conditional image generation
|
||||
>>> # ImageNet-64 class label 145 corresponds to king penguins
|
||||
>>> image = pipe(num_inference_steps=1, class_labels=145).images[0]
|
||||
>>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png")
|
||||
|
||||
>>> # Multistep sampling, class-conditional image generation
|
||||
>>> # Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo:
|
||||
>>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
|
||||
>>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0]
|
||||
>>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class ConsistencyModelPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1].
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
[1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
|
||||
https://arxiv.org/pdf/2303.01469
|
||||
|
||||
Args:
|
||||
unet ([`UNet2DModel`]):
|
||||
Unconditional or class-conditional U-Net architecture to denoise image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible
|
||||
with [`CMStochasticIterativeScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.safety_checker = None
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.unet]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Follows diffusers.VaeImageProcessor.postprocess
|
||||
def postprocess_image(self, sample: torch.FloatTensor, output_type: str = "pil"):
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(
|
||||
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
|
||||
)
|
||||
|
||||
# Equivalent to diffusers.VaeImageProcessor.denormalize
|
||||
sample = (sample / 2 + 0.5).clamp(0, 1)
|
||||
if output_type == "pt":
|
||||
return sample
|
||||
|
||||
# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
|
||||
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
|
||||
if output_type == "np":
|
||||
return sample
|
||||
|
||||
# Output_type must be 'pil'
|
||||
sample = self.numpy_to_pil(sample)
|
||||
return sample
|
||||
|
||||
def prepare_class_labels(self, batch_size, device, class_labels=None):
|
||||
if self.unet.config.num_class_embeds is not None:
|
||||
if isinstance(class_labels, list):
|
||||
class_labels = torch.tensor(class_labels, dtype=torch.int)
|
||||
elif isinstance(class_labels, int):
|
||||
assert batch_size == 1, "Batch size must be 1 if classes is an int"
|
||||
class_labels = torch.tensor([class_labels], dtype=torch.int)
|
||||
elif class_labels is None:
|
||||
# Randomly generate batch_size class labels
|
||||
# TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
|
||||
class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
|
||||
class_labels = class_labels.to(device)
|
||||
else:
|
||||
class_labels = None
|
||||
return class_labels
|
||||
|
||||
def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
|
||||
if num_inference_steps is None and timesteps is None:
|
||||
raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")
|
||||
|
||||
if num_inference_steps is not None and timesteps is not None:
|
||||
logger.warning(
|
||||
f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;"
|
||||
" `timesteps` will be used over `num_inference_steps`."
|
||||
)
|
||||
|
||||
if latents is not None:
|
||||
expected_shape = (batch_size, 3, img_size, img_size)
|
||||
if latents.shape != expected_shape:
|
||||
raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
class_labels: Optional[Union[torch.Tensor, List[int], int]] = None,
|
||||
num_inference_steps: int = 1,
|
||||
timesteps: List[int] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*):
|
||||
Optional class labels for conditioning class-conditional consistency models. Will not be used if the
|
||||
model is not class-conditional.
|
||||
num_inference_steps (`int`, *optional*, defaults to 1):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
|
||||
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
# 0. Prepare call parameters
|
||||
img_size = self.unet.config.sample_size
|
||||
device = self._execution_device
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)
|
||||
|
||||
# 2. Prepare image latents
|
||||
# Sample image latents x_0 ~ N(0, sigma_0^2 * I)
|
||||
sample = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
num_channels=self.unet.config.in_channels,
|
||||
height=img_size,
|
||||
width=img_size,
|
||||
dtype=self.unet.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
# 3. Handle class_labels for class-conditional models
|
||||
class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if timesteps is not None:
|
||||
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Denoising loop
|
||||
# Multistep sampling: implements Algorithm 1 in the paper
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
scaled_sample = self.scheduler.scale_model_input(sample, t)
|
||||
model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0]
|
||||
|
||||
sample = self.scheduler.step(model_output, t, sample, generator=generator)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, sample)
|
||||
|
||||
# 6. Post-process image sample
|
||||
image = self.postprocess_image(sample, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -15,5 +15,5 @@ else:
|
||||
from .pipeline_kandinsky import KandinskyPipeline
|
||||
from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline
|
||||
from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline
|
||||
from .pipeline_kandinsky_prior import KandinskyPriorPipeline
|
||||
from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput
|
||||
from .text_encoder import MultilingualCLIP
|
||||
|
||||
@@ -115,6 +115,7 @@ class KandinskyPipeline(DiffusionPipeline):
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
@@ -275,6 +275,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
@@ -274,6 +274,7 @@ class KandinskyPriorPipeline(DiffusionPipeline):
|
||||
|
||||
return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
from .pipeline_kandinsky2_2 import KandinskyV22Pipeline
|
||||
from .pipeline_kandinsky2_2_controlnet import KandinskyV22ControlnetPipeline
|
||||
from .pipeline_kandinsky2_2_controlnet_img2img import KandinskyV22ControlnetImg2ImgPipeline
|
||||
from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline
|
||||
from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline
|
||||
from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline
|
||||
from .pipeline_kandinsky2_2_prior_emb2emb import KandinskyV22PriorEmb2EmbPipeline
|
||||
@@ -0,0 +1,317 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
|
||||
>>> import torch
|
||||
|
||||
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior")
|
||||
>>> pipe_prior.to("cuda")
|
||||
>>> prompt = "red cat, 4k photo"
|
||||
>>> out = pipe_prior(prompt)
|
||||
>>> image_emb = out.image_embeds
|
||||
>>> zero_image_emb = out.negative_image_embeds
|
||||
>>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder")
|
||||
>>> pipe.to("cuda")
|
||||
>>> image = pipe(
|
||||
... image_embeds=image_emb,
|
||||
... negative_image_embeds=zero_image_emb,
|
||||
... height=768,
|
||||
... width=768,
|
||||
... num_inference_steps=50,
|
||||
... ).images
|
||||
>>> image[0].save("cat.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
if height % scale_factor**2 != 0:
|
||||
new_height += 1
|
||||
new_width = width // scale_factor**2
|
||||
if width % scale_factor**2 != 0:
|
||||
new_width += 1
|
||||
return new_height * scale_factor, new_width * scale_factor
|
||||
|
||||
|
||||
class KandinskyV22Pipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-to-image generation using Kandinsky
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
movq ([`VQModel`]):
|
||||
MoVQ Decoder to generate the image from the latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDPMScheduler,
|
||||
movq: VQModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
movq=movq,
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.unet,
|
||||
self.movq,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.unet, self.movq]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 100,
|
||||
guidance_scale: float = 4.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
Function invoked when calling the pipeline for generation.
|
||||
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for text prompt, that will be used to condition the image generation.
|
||||
negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for negative text prompt, will be used to condition the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
device = self._execution_device
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(image_embeds, list):
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
batch_size = image_embeds.shape[0] * num_images_per_prompt
|
||||
if isinstance(negative_image_embeds, list):
|
||||
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
|
||||
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
|
||||
|
||||
# create initial latent
|
||||
latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
image_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
self.scheduler,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
added_cond_kwargs = {"image_embeds": image_embeds}
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=None,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
_, variance_pred_text = variance_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
|
||||
|
||||
if not (
|
||||
hasattr(self.scheduler.config, "variance_type")
|
||||
and self.scheduler.config.variance_type in ["learned", "learned_range"]
|
||||
):
|
||||
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
generator=generator,
|
||||
)[0]
|
||||
# post-processing
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,372 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
|
||||
>>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline
|
||||
>>> from transformers import pipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
|
||||
>>> def make_hint(image, depth_estimator):
|
||||
... image = depth_estimator(image)["depth"]
|
||||
... image = np.array(image)
|
||||
... image = image[:, :, None]
|
||||
... image = np.concatenate([image, image, image], axis=2)
|
||||
... detected_map = torch.from_numpy(image).float() / 255.0
|
||||
... hint = detected_map.permute(2, 0, 1)
|
||||
... return hint
|
||||
|
||||
|
||||
>>> depth_estimator = pipeline("depth-estimation")
|
||||
|
||||
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe_prior = pipe_prior.to("cuda")
|
||||
|
||||
>>> pipe = KandinskyV22ControlnetPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
|
||||
>>> img = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
... "/kandinsky/cat.png"
|
||||
... ).resize((768, 768))
|
||||
|
||||
>>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
|
||||
|
||||
>>> prompt = "A robot, 4k photo"
|
||||
>>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
|
||||
|
||||
>>> generator = torch.Generator(device="cuda").manual_seed(43)
|
||||
|
||||
>>> image_emb, zero_image_emb = pipe_prior(
|
||||
... prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator
|
||||
... ).to_tuple()
|
||||
|
||||
>>> images = pipe(
|
||||
... image_embeds=image_emb,
|
||||
... negative_image_embeds=zero_image_emb,
|
||||
... hint=hint,
|
||||
... num_inference_steps=50,
|
||||
... generator=generator,
|
||||
... height=768,
|
||||
... width=768,
|
||||
... ).images
|
||||
|
||||
>>> images[0].save("robot_cat.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
if height % scale_factor**2 != 0:
|
||||
new_height += 1
|
||||
new_width = width // scale_factor**2
|
||||
if width % scale_factor**2 != 0:
|
||||
new_width += 1
|
||||
return new_height * scale_factor, new_width * scale_factor
|
||||
|
||||
|
||||
class KandinskyV22ControlnetPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-to-image generation using Kandinsky
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
scheduler ([`DDIMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
movq ([`VQModel`]):
|
||||
MoVQ Decoder to generate the image from the latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDPMScheduler,
|
||||
movq: VQModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
movq=movq,
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.unet,
|
||||
self.movq,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.unet, self.movq]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
hint: torch.FloatTensor,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 100,
|
||||
guidance_scale: float = 4.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
hint (`torch.FloatTensor`):
|
||||
The controlnet condition.
|
||||
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for text prompt, that will be used to condition the image generation.
|
||||
negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for negative text prompt, will be used to condition the image generation.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
device = self._execution_device
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(image_embeds, list):
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
if isinstance(negative_image_embeds, list):
|
||||
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
|
||||
if isinstance(hint, list):
|
||||
hint = torch.cat(hint, dim=0)
|
||||
|
||||
batch_size = image_embeds.shape[0] * num_images_per_prompt
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
|
||||
hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.movq.config.latent_channels
|
||||
|
||||
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
|
||||
|
||||
# create initial latent
|
||||
latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
image_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
self.scheduler,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=None,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
_, variance_pred_text = variance_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
|
||||
|
||||
if not (
|
||||
hasattr(self.scheduler.config, "variance_type")
|
||||
and self.scheduler.config.variance_type in ["learned", "learned_range"]
|
||||
):
|
||||
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
generator=generator,
|
||||
)[0]
|
||||
# post-processing
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,434 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
|
||||
>>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline
|
||||
>>> from transformers import pipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
|
||||
>>> def make_hint(image, depth_estimator):
|
||||
... image = depth_estimator(image)["depth"]
|
||||
... image = np.array(image)
|
||||
... image = image[:, :, None]
|
||||
... image = np.concatenate([image, image, image], axis=2)
|
||||
... detected_map = torch.from_numpy(image).float() / 255.0
|
||||
... hint = detected_map.permute(2, 0, 1)
|
||||
... return hint
|
||||
|
||||
|
||||
>>> depth_estimator = pipeline("depth-estimation")
|
||||
|
||||
>>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe_prior = pipe_prior.to("cuda")
|
||||
|
||||
>>> pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> img = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
... "/kandinsky/cat.png"
|
||||
... ).resize((768, 768))
|
||||
|
||||
|
||||
>>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
|
||||
|
||||
>>> prompt = "A robot, 4k photo"
|
||||
>>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
|
||||
|
||||
>>> generator = torch.Generator(device="cuda").manual_seed(43)
|
||||
|
||||
>>> img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator)
|
||||
>>> negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)
|
||||
|
||||
>>> images = pipe(
|
||||
... image=img,
|
||||
... strength=0.5,
|
||||
... image_embeds=img_emb.image_embeds,
|
||||
... negative_image_embeds=negative_emb.image_embeds,
|
||||
... hint=hint,
|
||||
... num_inference_steps=50,
|
||||
... generator=generator,
|
||||
... height=768,
|
||||
... width=768,
|
||||
... ).images
|
||||
|
||||
>>> images[0].save("robot_cat.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
if height % scale_factor**2 != 0:
|
||||
new_height += 1
|
||||
new_width = width // scale_factor**2
|
||||
if width % scale_factor**2 != 0:
|
||||
new_width += 1
|
||||
return new_height * scale_factor, new_width * scale_factor
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
|
||||
def prepare_image(pil_image, w=512, h=512):
|
||||
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
|
||||
arr = np.array(pil_image.convert("RGB"))
|
||||
arr = arr.astype(np.float32) / 127.5 - 1
|
||||
arr = np.transpose(arr, [2, 0, 1])
|
||||
image = torch.from_numpy(arr).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for image-to-image generation using Kandinsky
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
scheduler ([`DDIMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
movq ([`VQModel`]):
|
||||
MoVQ Decoder to generate the image from the latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDPMScheduler,
|
||||
movq: VQModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
movq=movq,
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2_img2img.KandinskyV22Img2ImgPipeline.prepare_latents
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
if image.shape[1] == 4:
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
elif isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.movq.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = self.movq.config.scaling_factor * init_latents
|
||||
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
shape = init_latents.shape
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
|
||||
latents = init_latents
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.unet,
|
||||
self.movq,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.unet, self.movq]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
|
||||
negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
hint: torch.FloatTensor,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 100,
|
||||
guidance_scale: float = 4.0,
|
||||
strength: float = 0.3,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for text prompt, that will be used to condition the image generation.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
||||
process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded
|
||||
again.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||
hint (`torch.FloatTensor`):
|
||||
The controlnet condition.
|
||||
negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for negative text prompt, will be used to condition the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
device = self._execution_device
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(image_embeds, list):
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
if isinstance(negative_image_embeds, list):
|
||||
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
|
||||
if isinstance(hint, list):
|
||||
hint = torch.cat(hint, dim=0)
|
||||
|
||||
batch_size = image_embeds.shape[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
|
||||
hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
|
||||
raise ValueError(
|
||||
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
|
||||
)
|
||||
|
||||
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
|
||||
image = image.to(dtype=image_embeds.dtype, device=device)
|
||||
|
||||
latents = self.movq.encode(image)["latents"]
|
||||
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
|
||||
latents = self.prepare_latents(
|
||||
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
|
||||
)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=None,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
_, variance_pred_text = variance_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
|
||||
|
||||
if not (
|
||||
hasattr(self.scheduler.config, "variance_type")
|
||||
and self.scheduler.config.variance_type in ["learned", "learned_range"]
|
||||
):
|
||||
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
generator=generator,
|
||||
)[0]
|
||||
|
||||
# post-processing
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,398 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import torch
|
||||
|
||||
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe_prior.to("cuda")
|
||||
|
||||
>>> prompt = "A red cartoon frog, 4k"
|
||||
>>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
|
||||
|
||||
>>> pipe = KandinskyV22Img2ImgPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> init_image = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
... "/kandinsky/frog.png"
|
||||
... )
|
||||
|
||||
>>> image = pipe(
|
||||
... image=init_image,
|
||||
... image_embeds=image_emb,
|
||||
... negative_image_embeds=zero_image_emb,
|
||||
... height=768,
|
||||
... width=768,
|
||||
... num_inference_steps=100,
|
||||
... strength=0.2,
|
||||
... ).images
|
||||
|
||||
>>> image[0].save("red_frog.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
if height % scale_factor**2 != 0:
|
||||
new_height += 1
|
||||
new_width = width // scale_factor**2
|
||||
if width % scale_factor**2 != 0:
|
||||
new_width += 1
|
||||
return new_height * scale_factor, new_width * scale_factor
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
|
||||
def prepare_image(pil_image, w=512, h=512):
|
||||
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
|
||||
arr = np.array(pil_image.convert("RGB"))
|
||||
arr = arr.astype(np.float32) / 127.5 - 1
|
||||
arr = np.transpose(arr, [2, 0, 1])
|
||||
image = torch.from_numpy(arr).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for image-to-image generation using Kandinsky
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
scheduler ([`DDIMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
movq ([`VQModel`]):
|
||||
MoVQ Decoder to generate the image from the latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDPMScheduler,
|
||||
movq: VQModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
movq=movq,
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
if image.shape[1] == 4:
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
elif isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.movq.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = self.movq.config.scaling_factor * init_latents
|
||||
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
shape = init_latents.shape
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
|
||||
latents = init_latents
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.unet,
|
||||
self.movq,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.unet, self.movq]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
|
||||
negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 100,
|
||||
guidance_scale: float = 4.0,
|
||||
strength: float = 0.3,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for text prompt, that will be used to condition the image generation.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
||||
process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded
|
||||
again.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||
negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for negative text prompt, will be used to condition the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
device = self._execution_device
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(image_embeds, list):
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
batch_size = image_embeds.shape[0]
|
||||
if isinstance(negative_image_embeds, list):
|
||||
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
|
||||
raise ValueError(
|
||||
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
|
||||
)
|
||||
|
||||
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
|
||||
image = image.to(dtype=image_embeds.dtype, device=device)
|
||||
|
||||
latents = self.movq.encode(image)["latents"]
|
||||
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
|
||||
latents = self.prepare_latents(
|
||||
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
|
||||
)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
added_cond_kwargs = {"image_embeds": image_embeds}
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=None,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
_, variance_pred_text = variance_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
|
||||
|
||||
if not (
|
||||
hasattr(self.scheduler.config, "variance_type")
|
||||
and self.scheduler.config.variance_type in ["learned", "learned_range"]
|
||||
):
|
||||
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
generator=generator,
|
||||
)[0]
|
||||
|
||||
# post-processing
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,531 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
|
||||
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe_prior.to("cuda")
|
||||
|
||||
>>> prompt = "a hat"
|
||||
>>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
|
||||
|
||||
>>> pipe = KandinskyV22InpaintPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> init_image = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
... "/kandinsky/cat.png"
|
||||
... )
|
||||
|
||||
>>> mask = np.ones((768, 768), dtype=np.float32)
|
||||
>>> mask[:250, 250:-250] = 0
|
||||
|
||||
>>> out = pipe(
|
||||
... image=init_image,
|
||||
... mask_image=mask,
|
||||
... image_embeds=image_emb,
|
||||
... negative_image_embeds=zero_image_emb,
|
||||
... height=768,
|
||||
... width=768,
|
||||
... num_inference_steps=50,
|
||||
... )
|
||||
|
||||
>>> image = out.images[0]
|
||||
>>> image.save("cat_with_hat.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
if height % scale_factor**2 != 0:
|
||||
new_height += 1
|
||||
new_width = width // scale_factor**2
|
||||
if width % scale_factor**2 != 0:
|
||||
new_width += 1
|
||||
return new_height * scale_factor, new_width * scale_factor
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask
|
||||
def prepare_mask(masks):
|
||||
prepared_masks = []
|
||||
for mask in masks:
|
||||
old_mask = deepcopy(mask)
|
||||
for i in range(mask.shape[1]):
|
||||
for j in range(mask.shape[2]):
|
||||
if old_mask[0][i][j] == 1:
|
||||
continue
|
||||
if i != 0:
|
||||
mask[:, i - 1, j] = 0
|
||||
if j != 0:
|
||||
mask[:, i, j - 1] = 0
|
||||
if i != 0 and j != 0:
|
||||
mask[:, i - 1, j - 1] = 0
|
||||
if i != mask.shape[1] - 1:
|
||||
mask[:, i + 1, j] = 0
|
||||
if j != mask.shape[2] - 1:
|
||||
mask[:, i, j + 1] = 0
|
||||
if i != mask.shape[1] - 1 and j != mask.shape[2] - 1:
|
||||
mask[:, i + 1, j + 1] = 0
|
||||
prepared_masks.append(mask)
|
||||
return torch.stack(prepared_masks, dim=0)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask_and_masked_image
|
||||
def prepare_mask_and_masked_image(image, mask, height, width):
|
||||
r"""
|
||||
Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will
|
||||
be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
|
||||
the ``image`` and ``1`` for the ``mask``.
|
||||
|
||||
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
||||
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
||||
|
||||
Args:
|
||||
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
||||
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
||||
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
||||
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
||||
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
||||
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
||||
(ot the other way around).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
if mask is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined.")
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
|
||||
|
||||
# Batch single image
|
||||
if image.ndim == 3:
|
||||
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
# Batch and add channel dim for single mask
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Batch single mask or add channel dim
|
||||
if mask.ndim == 3:
|
||||
# Single batched mask, no channel dim or single mask not batched but channel dim
|
||||
if mask.shape[0] == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
# Batched masks no channel dim
|
||||
else:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
|
||||
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
|
||||
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
|
||||
|
||||
# Check image is in [-1, 1]
|
||||
if image.min() < -1 or image.max() > 1:
|
||||
raise ValueError("Image should be in [-1, 1] range")
|
||||
|
||||
# Check mask is in [0, 1]
|
||||
if mask.min() < 0 or mask.max() > 1:
|
||||
raise ValueError("Mask should be in [0, 1] range")
|
||||
|
||||
# Binarize mask
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
# Image as float32
|
||||
image = image.to(dtype=torch.float32)
|
||||
elif isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
|
||||
else:
|
||||
# preprocess image
|
||||
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
||||
# resize all images w.r.t passed height an width
|
||||
image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image]
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
|
||||
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
|
||||
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
return mask, image
|
||||
|
||||
|
||||
class KandinskyV22InpaintPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-guided image inpainting using Kandinsky2.1
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
scheduler ([`DDIMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
movq ([`VQModel`]):
|
||||
MoVQ Decoder to generate the image from the latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDPMScheduler,
|
||||
movq: VQModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
movq=movq,
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.unet,
|
||||
self.movq,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.unet, self.movq]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
||||
negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 100,
|
||||
guidance_scale: float = 4.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
Function invoked when calling the pipeline for generation.
|
||||
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for text prompt, that will be used to condition the image generation.
|
||||
image (`PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
|
||||
be masked out with `mask_image` and repainted according to `prompt`.
|
||||
mask_image (`np.array`):
|
||||
Tensor representing an image batch, to mask `image`. Black pixels in the mask will be repainted, while
|
||||
white pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single
|
||||
channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3,
|
||||
so the expected shape would be `(B, H, W, 1)`.
|
||||
negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
The clip image embeddings for negative text prompt, will be used to condition the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
device = self._execution_device
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(image_embeds, list):
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
batch_size = image_embeds.shape[0] * num_images_per_prompt
|
||||
if isinstance(negative_image_embeds, list):
|
||||
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
|
||||
# preprocess image and mask
|
||||
mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)
|
||||
|
||||
image = image.to(dtype=image_embeds.dtype, device=device)
|
||||
image = self.movq.encode(image)["latents"]
|
||||
|
||||
mask_image = mask_image.to(dtype=image_embeds.dtype, device=device)
|
||||
|
||||
image_shape = tuple(image.shape[-2:])
|
||||
mask_image = F.interpolate(
|
||||
mask_image,
|
||||
image_shape,
|
||||
mode="nearest",
|
||||
)
|
||||
mask_image = prepare_mask(mask_image)
|
||||
masked_image = image * mask_image
|
||||
|
||||
mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if do_classifier_free_guidance:
|
||||
mask_image = mask_image.repeat(2, 1, 1, 1)
|
||||
masked_image = masked_image.repeat(2, 1, 1, 1)
|
||||
|
||||
num_channels_latents = self.movq.config.latent_channels
|
||||
|
||||
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
|
||||
|
||||
# create initial latent
|
||||
latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
image_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
self.scheduler,
|
||||
)
|
||||
noise = torch.clone(latents)
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1)
|
||||
|
||||
added_cond_kwargs = {"image_embeds": image_embeds}
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=None,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
_, variance_pred_text = variance_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
|
||||
|
||||
if not (
|
||||
hasattr(self.scheduler.config, "variance_type")
|
||||
and self.scheduler.config.variance_type in ["learned", "learned_range"]
|
||||
):
|
||||
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
generator=generator,
|
||||
)[0]
|
||||
init_latents_proper = image[:1]
|
||||
init_mask = mask_image[:1]
|
||||
|
||||
if i < len(timesteps_tensor) - 1:
|
||||
noise_timestep = timesteps_tensor[i + 1]
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
init_latents_proper, noise, torch.tensor([noise_timestep])
|
||||
)
|
||||
|
||||
latents = init_mask * init_latents_proper + (1 - init_mask) * latents
|
||||
# post-processing
|
||||
latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,541 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..kandinsky import KandinskyPriorPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
|
||||
>>> import torch
|
||||
|
||||
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior")
|
||||
>>> pipe_prior.to("cuda")
|
||||
>>> prompt = "red cat, 4k photo"
|
||||
>>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple()
|
||||
|
||||
>>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder")
|
||||
>>> pipe.to("cuda")
|
||||
>>> image = pipe(
|
||||
... image_embeds=image_emb,
|
||||
... negative_image_embeds=negative_image_emb,
|
||||
... height=768,
|
||||
... width=768,
|
||||
... num_inference_steps=50,
|
||||
... ).images
|
||||
>>> image[0].save("cat.png")
|
||||
```
|
||||
"""
|
||||
|
||||
EXAMPLE_INTERPOLATE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import PIL
|
||||
>>> import torch
|
||||
>>> from torchvision import transforms
|
||||
|
||||
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe_prior.to("cuda")
|
||||
>>> img1 = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
... "/kandinsky/cat.png"
|
||||
... )
|
||||
>>> img2 = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
... "/kandinsky/starry_night.jpeg"
|
||||
... )
|
||||
>>> images_texts = ["a cat", img1, img2]
|
||||
>>> weights = [0.3, 0.3, 0.4]
|
||||
>>> out = pipe_prior.interpolate(images_texts, weights)
|
||||
>>> pipe = KandinskyV22Pipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> image = pipe(
|
||||
... image_embeds=out.image_embeds,
|
||||
... negative_image_embeds=out.negative_image_embeds,
|
||||
... height=768,
|
||||
... width=768,
|
||||
... num_inference_steps=50,
|
||||
... ).images[0]
|
||||
>>> image.save("starry_cat.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class KandinskyV22PriorPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating image prior for Kandinsky
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
prior ([`PriorTransformer`]):
|
||||
The canonincal unCLIP prior to approximate the image embedding from the text embedding.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen image-encoder.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
Frozen text-encoder.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`UnCLIPScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
image_processor ([`CLIPImageProcessor`]):
|
||||
A image_processor to be used to preprocess image from clip.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prior: PriorTransformer,
|
||||
image_encoder: CLIPVisionModelWithProjection,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
scheduler: UnCLIPScheduler,
|
||||
image_processor: CLIPImageProcessor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
prior=prior,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
|
||||
def interpolate(
|
||||
self,
|
||||
images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
|
||||
weights: List[float],
|
||||
num_images_per_prompt: int = 1,
|
||||
num_inference_steps: int = 25,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
negative_prior_prompt: Optional[str] = None,
|
||||
negative_prompt: Union[str] = "",
|
||||
guidance_scale: float = 4.0,
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Function invoked when using the prior pipeline for interpolation.
|
||||
|
||||
Args:
|
||||
images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
|
||||
list of prompts and images to guide the image generation.
|
||||
weights: (`List[float]`):
|
||||
list of weights for each condition in `images_and_prompts`
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
negative_prior_prompt (`str`, *optional*):
|
||||
The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
|
||||
`guidance_scale` is less than `1`).
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
|
||||
`guidance_scale` is less than `1`).
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`KandinskyPriorPipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
device = device or self.device
|
||||
|
||||
if len(images_and_prompts) != len(weights):
|
||||
raise ValueError(
|
||||
f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
|
||||
)
|
||||
|
||||
image_embeddings = []
|
||||
for cond, weight in zip(images_and_prompts, weights):
|
||||
if isinstance(cond, str):
|
||||
image_emb = self(
|
||||
cond,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
negative_prompt=negative_prior_prompt,
|
||||
guidance_scale=guidance_scale,
|
||||
).image_embeds.unsqueeze(0)
|
||||
|
||||
elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
|
||||
if isinstance(cond, PIL.Image.Image):
|
||||
cond = (
|
||||
self.image_processor(cond, return_tensors="pt")
|
||||
.pixel_values[0]
|
||||
.unsqueeze(0)
|
||||
.to(dtype=self.image_encoder.dtype, device=device)
|
||||
)
|
||||
|
||||
image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}"
|
||||
)
|
||||
|
||||
image_embeddings.append(image_emb * weight)
|
||||
|
||||
image_emb = torch.cat(image_embeddings).sum(dim=0)
|
||||
|
||||
out_zero = self(
|
||||
negative_prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
negative_prompt=negative_prior_prompt,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
||||
zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds
|
||||
|
||||
return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed
|
||||
def get_zero_embed(self, batch_size=1, device=None):
|
||||
device = device or self.device
|
||||
zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
|
||||
device=device, dtype=self.image_encoder.dtype
|
||||
)
|
||||
zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
|
||||
zero_image_emb = zero_image_emb.repeat(batch_size, 1)
|
||||
return zero_image_emb
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.image_encoder,
|
||||
self.text_encoder,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.text_encoder.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
text_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device))
|
||||
|
||||
prompt_embeds = text_encoder_output.text_embeds
|
||||
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
|
||||
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
|
||||
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
seq_len = uncond_text_encoder_hidden_states.shape[1]
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
# done duplicates
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
|
||||
|
||||
text_mask = torch.cat([uncond_text_mask, text_mask])
|
||||
|
||||
return prompt_embeds, text_encoder_hidden_states, text_mask
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
num_inference_steps: int = 25,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
output_type: Optional[str] = "pt", # pt only
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
output_type (`str`, *optional*, defaults to `"pt"`):
|
||||
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
|
||||
(`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`KandinskyPriorPipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
elif not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
elif not isinstance(negative_prompt, list) and negative_prompt is not None:
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
# if the negative prompt is defined we double the batch size to
|
||||
# directly retrieve the negative prompt embedding
|
||||
if negative_prompt is not None:
|
||||
prompt = prompt + negative_prompt
|
||||
negative_prompt = 2 * negative_prompt
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
batch_size = len(prompt)
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
# prior
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
prior_timesteps_tensor = self.scheduler.timesteps
|
||||
|
||||
embedding_dim = self.prior.config.embedding_dim
|
||||
|
||||
latents = self.prepare_latents(
|
||||
(batch_size, embedding_dim),
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
self.scheduler,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
predicted_image_embedding = self.prior(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
proj_embedding=prompt_embeds,
|
||||
encoder_hidden_states=text_encoder_hidden_states,
|
||||
attention_mask=text_mask,
|
||||
).predicted_image_embedding
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
|
||||
predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
|
||||
predicted_image_embedding_text - predicted_image_embedding_uncond
|
||||
)
|
||||
|
||||
if i + 1 == prior_timesteps_tensor.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = prior_timesteps_tensor[i + 1]
|
||||
|
||||
latents = self.scheduler.step(
|
||||
predicted_image_embedding,
|
||||
timestep=t,
|
||||
sample=latents,
|
||||
generator=generator,
|
||||
prev_timestep=prev_timestep,
|
||||
).prev_sample
|
||||
|
||||
latents = self.prior.post_process_latents(latents)
|
||||
|
||||
image_embeddings = latents
|
||||
|
||||
# if negative prompt has been defined, we retrieve split the image embedding into two
|
||||
if negative_prompt is None:
|
||||
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
|
||||
else:
|
||||
image_embeddings, zero_embeds = image_embeddings.chunk(2)
|
||||
|
||||
if output_type not in ["pt", "np"]:
|
||||
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
|
||||
|
||||
if output_type == "np":
|
||||
image_embeddings = image_embeddings.cpu().numpy()
|
||||
zero_embeds = zero_embeds.cpu().numpy()
|
||||
|
||||
if not return_dict:
|
||||
return (image_embeddings, zero_embeds)
|
||||
|
||||
return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
|
||||
@@ -0,0 +1,605 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..kandinsky import KandinskyPriorPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline
|
||||
>>> import torch
|
||||
|
||||
>>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe_prior.to("cuda")
|
||||
|
||||
>>> prompt = "red cat, 4k photo"
|
||||
>>> img = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
... "/kandinsky/cat.png"
|
||||
... )
|
||||
>>> image_emb, nagative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple()
|
||||
|
||||
>>> pipe = KandinskyPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-decoder, torch_dtype=torch.float16"
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = pipe(
|
||||
... image_embeds=image_emb,
|
||||
... negative_image_embeds=negative_image_emb,
|
||||
... height=768,
|
||||
... width=768,
|
||||
... num_inference_steps=100,
|
||||
... ).images
|
||||
|
||||
>>> image[0].save("cat.png")
|
||||
```
|
||||
"""
|
||||
|
||||
EXAMPLE_INTERPOLATE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22Pipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import PIL
|
||||
|
||||
>>> import torch
|
||||
>>> from torchvision import transforms
|
||||
|
||||
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe_prior.to("cuda")
|
||||
|
||||
>>> img1 = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
... "/kandinsky/cat.png"
|
||||
... )
|
||||
|
||||
>>> img2 = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
... "/kandinsky/starry_night.jpeg"
|
||||
... )
|
||||
|
||||
>>> images_texts = ["a cat", img1, img2]
|
||||
>>> weights = [0.3, 0.3, 0.4]
|
||||
>>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
|
||||
|
||||
>>> pipe = KandinskyV22Pipeline.from_pretrained(
|
||||
... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = pipe(
|
||||
... image_embeds=image_emb,
|
||||
... negative_image_embeds=zero_image_emb,
|
||||
... height=768,
|
||||
... width=768,
|
||||
... num_inference_steps=150,
|
||||
... ).images[0]
|
||||
|
||||
>>> image.save("starry_cat.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating image prior for Kandinsky
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
prior ([`PriorTransformer`]):
|
||||
The canonincal unCLIP prior to approximate the image embedding from the text embedding.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen image-encoder.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
Frozen text-encoder.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`UnCLIPScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prior: PriorTransformer,
|
||||
image_encoder: CLIPVisionModelWithProjection,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
scheduler: UnCLIPScheduler,
|
||||
image_processor: CLIPImageProcessor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
prior=prior,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
|
||||
def interpolate(
|
||||
self,
|
||||
images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
|
||||
weights: List[float],
|
||||
num_images_per_prompt: int = 1,
|
||||
num_inference_steps: int = 25,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
negative_prior_prompt: Optional[str] = None,
|
||||
negative_prompt: Union[str] = "",
|
||||
guidance_scale: float = 4.0,
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Function invoked when using the prior pipeline for interpolation.
|
||||
|
||||
Args:
|
||||
images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
|
||||
list of prompts and images to guide the image generation.
|
||||
weights: (`List[float]`):
|
||||
list of weights for each condition in `images_and_prompts`
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
negative_prior_prompt (`str`, *optional*):
|
||||
The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
|
||||
`guidance_scale` is less than `1`).
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
|
||||
`guidance_scale` is less than `1`).
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`KandinskyPriorPipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
device = device or self.device
|
||||
|
||||
if len(images_and_prompts) != len(weights):
|
||||
raise ValueError(
|
||||
f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
|
||||
)
|
||||
|
||||
image_embeddings = []
|
||||
for cond, weight in zip(images_and_prompts, weights):
|
||||
if isinstance(cond, str):
|
||||
image_emb = self(
|
||||
cond,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
negative_prompt=negative_prior_prompt,
|
||||
guidance_scale=guidance_scale,
|
||||
).image_embeds.unsqueeze(0)
|
||||
|
||||
elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
|
||||
image_emb = self._encode_image(
|
||||
cond, device=device, num_images_per_prompt=num_images_per_prompt
|
||||
).unsqueeze(0)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}"
|
||||
)
|
||||
|
||||
image_embeddings.append(image_emb * weight)
|
||||
|
||||
image_emb = torch.cat(image_embeddings).sum(dim=0)
|
||||
|
||||
return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=torch.randn_like(image_emb))
|
||||
|
||||
def _encode_image(
|
||||
self,
|
||||
image: Union[torch.Tensor, List[PIL.Image.Image]],
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
):
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.image_processor(image, return_tensors="pt").pixel_values.to(
|
||||
dtype=self.image_encoder.dtype, device=device
|
||||
)
|
||||
|
||||
image_emb = self.image_encoder(image)["image_embeds"] # B, D
|
||||
image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
image_emb.to(device=device)
|
||||
|
||||
return image_emb
|
||||
|
||||
def prepare_latents(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
emb = emb.to(device=device, dtype=dtype)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
init_latents = emb
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
shape = init_latents.shape
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
latents = init_latents
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed
|
||||
def get_zero_embed(self, batch_size=1, device=None):
|
||||
device = device or self.device
|
||||
zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
|
||||
device=device, dtype=self.image_encoder.dtype
|
||||
)
|
||||
zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
|
||||
zero_image_emb = zero_image_emb.repeat(batch_size, 1)
|
||||
return zero_image_emb
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.image_encoder,
|
||||
self.text_encoder,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.text_encoder.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
text_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device))
|
||||
|
||||
prompt_embeds = text_encoder_output.text_embeds
|
||||
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
|
||||
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
|
||||
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
seq_len = uncond_text_encoder_hidden_states.shape[1]
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
# done duplicates
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
|
||||
|
||||
text_mask = torch.cat([uncond_text_mask, text_mask])
|
||||
|
||||
return prompt_embeds, text_encoder_hidden_states, text_mask
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]],
|
||||
strength: float = 0.3,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
num_inference_steps: int = 25,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
output_type: Optional[str] = "pt", # pt only
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `emb`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added.
|
||||
emb (`torch.FloatTensor`):
|
||||
The image embedding.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
output_type (`str`, *optional*, defaults to `"pt"`):
|
||||
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
|
||||
(`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`KandinskyPriorPipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
elif not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
elif not isinstance(negative_prompt, list) and negative_prompt is not None:
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
# if the negative prompt is defined we double the batch size to
|
||||
# directly retrieve the negative prompt embedding
|
||||
if negative_prompt is not None:
|
||||
prompt = prompt + negative_prompt
|
||||
negative_prompt = 2 * negative_prompt
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
batch_size = len(prompt)
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
if not isinstance(image, List):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
|
||||
if isinstance(image, torch.Tensor) and image.ndim == 2:
|
||||
# allow user to pass image_embeds directly
|
||||
image_embeds = image.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
elif isinstance(image, torch.Tensor) and image.ndim != 4:
|
||||
raise ValueError(
|
||||
f" if pass `image` as pytorch tensor, or a list of pytorch tensor, please make sure each tensor has shape [batch_size, channels, height, width], currently {image[0].unsqueeze(0).shape}"
|
||||
)
|
||||
else:
|
||||
image_embeds = self._encode_image(image, device, num_images_per_prompt)
|
||||
|
||||
# prior
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
latents = image_embeds
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size)
|
||||
latents = self.prepare_latents(
|
||||
latents,
|
||||
latent_timestep,
|
||||
batch_size // num_images_per_prompt,
|
||||
num_images_per_prompt,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
predicted_image_embedding = self.prior(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
proj_embedding=prompt_embeds,
|
||||
encoder_hidden_states=text_encoder_hidden_states,
|
||||
attention_mask=text_mask,
|
||||
).predicted_image_embedding
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
|
||||
predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
|
||||
predicted_image_embedding_text - predicted_image_embedding_uncond
|
||||
)
|
||||
|
||||
if i + 1 == timesteps.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = timesteps[i + 1]
|
||||
|
||||
latents = self.scheduler.step(
|
||||
predicted_image_embedding,
|
||||
timestep=t,
|
||||
sample=latents,
|
||||
generator=generator,
|
||||
prev_timestep=prev_timestep,
|
||||
).prev_sample
|
||||
|
||||
latents = self.prior.post_process_latents(latents)
|
||||
|
||||
image_embeddings = latents
|
||||
|
||||
# if negative prompt has been defined, we retrieve split the image embedding into two
|
||||
if negative_prompt is None:
|
||||
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
|
||||
else:
|
||||
image_embeddings, zero_embeds = image_embeddings.chunk(2)
|
||||
|
||||
if output_type not in ["pt", "np"]:
|
||||
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
|
||||
|
||||
if output_type == "np":
|
||||
image_embeddings = image_embeddings.cpu().numpy()
|
||||
zero_embeds = zero_embeds.cpu().numpy()
|
||||
|
||||
if not return_dict:
|
||||
return (image_embeddings, zero_embeds)
|
||||
|
||||
return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
|
||||
@@ -204,7 +204,7 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
||||
transformers_index_format = r"\d{5}-of-\d{5}"
|
||||
|
||||
if variant is not None:
|
||||
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors`
|
||||
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
|
||||
variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
@@ -213,7 +213,7 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors`
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
|
||||
non_variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
@@ -1168,7 +1168,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
||||
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
@@ -1213,6 +1213,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
if len(variant_filenames) == 0 and variant is not None:
|
||||
deprecation_message = (
|
||||
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
||||
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
|
||||
"if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as defaulting to non-variant"
|
||||
"modeling files is deprecated."
|
||||
)
|
||||
deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False)
|
||||
|
||||
# remove ignored filenames
|
||||
model_filenames = set(model_filenames) - set(ignore_filenames)
|
||||
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
||||
@@ -1302,7 +1311,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
snapshot_folder = Path(config_file).parent
|
||||
pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files)
|
||||
|
||||
if pipeline_is_cached:
|
||||
if pipeline_is_cached and not force_download:
|
||||
# if the pipeline is cached, we can directly return it
|
||||
# else call snapshot_download
|
||||
return snapshot_folder
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
|
||||
else:
|
||||
from .camera import create_pan_cameras
|
||||
from .pipeline_shap_e import ShapEPipeline
|
||||
from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline
|
||||
from .renderer import (
|
||||
BoundingBoxVolume,
|
||||
ImportanceRaySampler,
|
||||
MLPNeRFModelOutput,
|
||||
MLPNeRSTFModel,
|
||||
ShapEParamsProjModel,
|
||||
ShapERenderer,
|
||||
StratifiedRaySampler,
|
||||
VoidNeRFModel,
|
||||
)
|
||||
@@ -0,0 +1,147 @@
|
||||
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DifferentiableProjectiveCamera:
|
||||
"""
|
||||
Implements a batch, differentiable, standard pinhole camera
|
||||
"""
|
||||
|
||||
origin: torch.Tensor # [batch_size x 3]
|
||||
x: torch.Tensor # [batch_size x 3]
|
||||
y: torch.Tensor # [batch_size x 3]
|
||||
z: torch.Tensor # [batch_size x 3]
|
||||
width: int
|
||||
height: int
|
||||
x_fov: float
|
||||
y_fov: float
|
||||
shape: Tuple[int]
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
|
||||
assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
|
||||
assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2
|
||||
|
||||
def resolution(self):
|
||||
return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))
|
||||
|
||||
def fov(self):
|
||||
return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))
|
||||
|
||||
def get_image_coords(self) -> torch.Tensor:
|
||||
"""
|
||||
:return: coords of shape (width * height, 2)
|
||||
"""
|
||||
pixel_indices = torch.arange(self.height * self.width)
|
||||
coords = torch.stack(
|
||||
[
|
||||
pixel_indices % self.width,
|
||||
torch.div(pixel_indices, self.width, rounding_mode="trunc"),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
return coords
|
||||
|
||||
@property
|
||||
def camera_rays(self):
|
||||
batch_size, *inner_shape = self.shape
|
||||
inner_batch_size = int(np.prod(inner_shape))
|
||||
|
||||
coords = self.get_image_coords()
|
||||
coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
|
||||
rays = self.get_camera_rays(coords)
|
||||
|
||||
rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)
|
||||
|
||||
return rays
|
||||
|
||||
def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, *shape, n_coords = coords.shape
|
||||
assert n_coords == 2
|
||||
assert batch_size == self.origin.shape[0]
|
||||
|
||||
flat = coords.view(batch_size, -1, 2)
|
||||
|
||||
res = self.resolution()
|
||||
fov = self.fov()
|
||||
|
||||
fracs = (flat.float() / (res - 1)) * 2 - 1
|
||||
fracs = fracs * torch.tan(fov / 2)
|
||||
|
||||
fracs = fracs.view(batch_size, -1, 2)
|
||||
directions = (
|
||||
self.z.view(batch_size, 1, 3)
|
||||
+ self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
|
||||
+ self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
|
||||
)
|
||||
directions = directions / directions.norm(dim=-1, keepdim=True)
|
||||
rays = torch.stack(
|
||||
[
|
||||
torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
|
||||
directions,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
return rays.view(batch_size, *shape, 2, 3)
|
||||
|
||||
def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
|
||||
"""
|
||||
Creates a new camera for the resized view assuming the aspect ratio does not change.
|
||||
"""
|
||||
assert width * self.height == height * self.width, "The aspect ratio should not change."
|
||||
return DifferentiableProjectiveCamera(
|
||||
origin=self.origin,
|
||||
x=self.x,
|
||||
y=self.y,
|
||||
z=self.z,
|
||||
width=width,
|
||||
height=height,
|
||||
x_fov=self.x_fov,
|
||||
y_fov=self.y_fov,
|
||||
)
|
||||
|
||||
|
||||
def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
|
||||
origins = []
|
||||
xs = []
|
||||
ys = []
|
||||
zs = []
|
||||
for theta in np.linspace(0, 2 * np.pi, num=20):
|
||||
z = np.array([np.sin(theta), np.cos(theta), -0.5])
|
||||
z /= np.sqrt(np.sum(z**2))
|
||||
origin = -z * 4
|
||||
x = np.array([np.cos(theta), -np.sin(theta), 0.0])
|
||||
y = np.cross(z, x)
|
||||
origins.append(origin)
|
||||
xs.append(x)
|
||||
ys.append(y)
|
||||
zs.append(z)
|
||||
return DifferentiableProjectiveCamera(
|
||||
origin=torch.from_numpy(np.stack(origins, axis=0)).float(),
|
||||
x=torch.from_numpy(np.stack(xs, axis=0)).float(),
|
||||
y=torch.from_numpy(np.stack(ys, axis=0)).float(),
|
||||
z=torch.from_numpy(np.stack(zs, axis=0)).float(),
|
||||
width=size,
|
||||
height=size,
|
||||
x_fov=0.7,
|
||||
y_fov=0.7,
|
||||
shape=(1, len(xs)),
|
||||
)
|
||||
@@ -0,0 +1,390 @@
|
||||
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import HeunDiscreteScheduler
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from .renderer import ShapERenderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
>>> from diffusers.utils import export_to_gif
|
||||
|
||||
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
>>> repo = "openai/shap-e"
|
||||
>>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to(device)
|
||||
|
||||
>>> guidance_scale = 15.0
|
||||
>>> prompt = "a shark"
|
||||
|
||||
>>> images = pipe(
|
||||
... prompt,
|
||||
... guidance_scale=guidance_scale,
|
||||
... num_inference_steps=64,
|
||||
... frame_size=256,
|
||||
... ).images
|
||||
|
||||
>>> gif_path = export_to_gif(images[0], "shark_3d.gif")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShapEPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for ShapEPipeline.
|
||||
|
||||
Args:
|
||||
images (`torch.FloatTensor`)
|
||||
a list of images for 3D rendering
|
||||
"""
|
||||
|
||||
images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]]
|
||||
|
||||
|
||||
class ShapEPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
prior ([`PriorTransformer`]):
|
||||
The canonincal unCLIP prior to approximate the image embedding from the text embedding.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
Frozen text-encoder.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`HeunDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
renderer ([`ShapERenderer`]):
|
||||
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
|
||||
with the NeRF rendering method
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prior: PriorTransformer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
scheduler: HeunDiscreteScheduler,
|
||||
renderer: ShapERenderer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
prior=prior,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
renderer=renderer,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [self.text_encoder, self.prior]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.text_encoder.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
):
|
||||
len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
# YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file
|
||||
self.tokenizer.pad_token_id = 0
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device))
|
||||
prompt_embeds = text_encoder_output.text_embeds
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
# in Shap-E it normalize the prompt_embeds and then later rescale it
|
||||
prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# Rescale the features to have unit variance
|
||||
prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
num_images_per_prompt: int = 1,
|
||||
num_inference_steps: int = 25,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
frame_size: int = 64,
|
||||
output_type: Optional[str] = "pil", # pil, np, latent
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
num_inference_steps (`int`, *optional*, defaults to 25):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
frame_size (`int`, *optional*, default to 64):
|
||||
the width and height of each image frame of the generated 3d output
|
||||
output_type (`str`, *optional*, defaults to `"pt"`):
|
||||
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
|
||||
(`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`ShapEPipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
|
||||
|
||||
# prior
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
num_embeddings = self.prior.config.num_embeddings
|
||||
embedding_dim = self.prior.config.embedding_dim
|
||||
|
||||
latents = self.prepare_latents(
|
||||
(batch_size, num_embeddings * embedding_dim),
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
self.scheduler,
|
||||
)
|
||||
|
||||
# YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
|
||||
latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
noise_pred = self.prior(
|
||||
scaled_model_input,
|
||||
timestep=t,
|
||||
proj_embedding=prompt_embeds,
|
||||
).predicted_image_embedding
|
||||
|
||||
# remove the variance
|
||||
noise_pred, _ = noise_pred.split(
|
||||
scaled_model_input.shape[2], dim=2
|
||||
) # batch_size, num_embeddings, embedding_dim
|
||||
|
||||
if do_classifier_free_guidance is not None:
|
||||
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
timestep=t,
|
||||
sample=latents,
|
||||
).prev_sample
|
||||
|
||||
if output_type == "latent":
|
||||
return ShapEPipelineOutput(images=latents)
|
||||
|
||||
images = []
|
||||
for i, latent in enumerate(latents):
|
||||
image = self.renderer.decode(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
ray_batch_size=4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
)
|
||||
images.append(image)
|
||||
|
||||
images = torch.stack(images)
|
||||
|
||||
if output_type not in ["np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
images = images.cpu().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (images,)
|
||||
|
||||
return ShapEPipelineOutput(images=images)
|
||||
@@ -0,0 +1,349 @@
|
||||
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import HeunDiscreteScheduler
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from .renderer import ShapERenderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from PIL import Image
|
||||
>>> import torch
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
>>> from diffusers.utils import export_to_gif, load_image
|
||||
|
||||
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
>>> repo = "openai/shap-e-img2img"
|
||||
>>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to(device)
|
||||
|
||||
>>> guidance_scale = 3.0
|
||||
>>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png"
|
||||
>>> image = load_image(image_url).convert("RGB")
|
||||
|
||||
>>> images = pipe(
|
||||
... image,
|
||||
... guidance_scale=guidance_scale,
|
||||
... num_inference_steps=64,
|
||||
... frame_size=256,
|
||||
... ).images
|
||||
|
||||
>>> gif_path = export_to_gif(images[0], "corgi_3d.gif")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShapEPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for ShapEPipeline.
|
||||
|
||||
Args:
|
||||
images (`torch.FloatTensor`)
|
||||
a list of images for 3D rendering
|
||||
"""
|
||||
|
||||
images: Union[PIL.Image.Image, np.ndarray]
|
||||
|
||||
|
||||
class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
prior ([`PriorTransformer`]):
|
||||
The canonincal unCLIP prior to approximate the image embedding from the text embedding.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
Frozen text-encoder.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`HeunDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
renderer ([`ShapERenderer`]):
|
||||
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
|
||||
with the NeRF rendering method
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prior: PriorTransformer,
|
||||
image_encoder: CLIPVisionModel,
|
||||
image_processor: CLIPImageProcessor,
|
||||
scheduler: HeunDiscreteScheduler,
|
||||
renderer: ShapERenderer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
prior=prior,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
scheduler=scheduler,
|
||||
renderer=renderer,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [self.image_encoder, self.prior]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.image_encoder, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.image_encoder.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_image(
|
||||
self,
|
||||
image,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
):
|
||||
if isinstance(image, List) and isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0)
|
||||
|
||||
image = image.to(dtype=self.image_encoder.dtype, device=device)
|
||||
|
||||
image_embeds = self.image_encoder(image)["last_hidden_state"]
|
||||
image_embeds = image_embeds[:, 1:, :].contiguous() # batch_size, dim, 256
|
||||
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
return image_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, List[PIL.Image.Image]],
|
||||
num_images_per_prompt: int = 1,
|
||||
num_inference_steps: int = 25,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
frame_size: int = 64,
|
||||
output_type: Optional[str] = "pil", # pil, np, latent
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
frame_size (`int`, *optional*, default to 64):
|
||||
the width and height of each image frame of the generated 3d output
|
||||
output_type (`str`, *optional*, defaults to `"pt"`):
|
||||
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
|
||||
(`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`ShapEPipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
batch_size = 1
|
||||
elif isinstance(image, torch.Tensor):
|
||||
batch_size = image.shape[0]
|
||||
elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)):
|
||||
batch_size = len(image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}"
|
||||
)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
|
||||
|
||||
# prior
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
num_embeddings = self.prior.config.num_embeddings
|
||||
embedding_dim = self.prior.config.embedding_dim
|
||||
|
||||
latents = self.prepare_latents(
|
||||
(batch_size, num_embeddings * embedding_dim),
|
||||
image_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
self.scheduler,
|
||||
)
|
||||
|
||||
# YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
|
||||
latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
noise_pred = self.prior(
|
||||
scaled_model_input,
|
||||
timestep=t,
|
||||
proj_embedding=image_embeds,
|
||||
).predicted_image_embedding
|
||||
|
||||
# remove the variance
|
||||
noise_pred, _ = noise_pred.split(
|
||||
scaled_model_input.shape[2], dim=2
|
||||
) # batch_size, num_embeddings, embedding_dim
|
||||
|
||||
if do_classifier_free_guidance is not None:
|
||||
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
timestep=t,
|
||||
sample=latents,
|
||||
).prev_sample
|
||||
|
||||
if output_type == "latent":
|
||||
return ShapEPipelineOutput(images=latents)
|
||||
|
||||
images = []
|
||||
for i, latent in enumerate(latents):
|
||||
print()
|
||||
image = self.renderer.decode(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
ray_batch_size=4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
)
|
||||
|
||||
images.append(image)
|
||||
|
||||
images = torch.stack(images)
|
||||
|
||||
if output_type not in ["np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
images = images.cpu().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (images,)
|
||||
|
||||
return ShapEPipelineOutput(images=images)
|
||||
@@ -0,0 +1,709 @@
|
||||
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models import ModelMixin
|
||||
from ...utils import BaseOutput
|
||||
from .camera import create_pan_cameras
|
||||
|
||||
|
||||
def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
|
||||
r"""
|
||||
Sample from the given discrete probability distribution with replacement.
|
||||
|
||||
The i-th bin is assumed to have mass pmf[i].
|
||||
|
||||
Args:
|
||||
pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all()
|
||||
n_samples: number of samples
|
||||
|
||||
Return:
|
||||
indices sampled with replacement
|
||||
"""
|
||||
|
||||
*shape, support_size, last_dim = pmf.shape
|
||||
assert last_dim == 1
|
||||
|
||||
cdf = torch.cumsum(pmf.view(-1, support_size), dim=1)
|
||||
inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device))
|
||||
|
||||
return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1)
|
||||
|
||||
|
||||
def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
|
||||
"""
|
||||
Concatenate x and its positional encodings, following NeRF.
|
||||
|
||||
Reference: https://arxiv.org/pdf/2210.04628.pdf
|
||||
"""
|
||||
if min_deg == max_deg:
|
||||
return x
|
||||
|
||||
scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device)
|
||||
*shape, dim = x.shape
|
||||
xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
|
||||
assert xb.shape[-1] == dim * (max_deg - min_deg)
|
||||
emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin()
|
||||
return torch.cat([x, emb], dim=-1)
|
||||
|
||||
|
||||
def encode_position(position):
|
||||
return posenc_nerf(position, min_deg=0, max_deg=15)
|
||||
|
||||
|
||||
def encode_direction(position, direction=None):
|
||||
if direction is None:
|
||||
return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
|
||||
else:
|
||||
return posenc_nerf(direction, min_deg=0, max_deg=8)
|
||||
|
||||
|
||||
def _sanitize_name(x: str) -> str:
|
||||
return x.replace(".", "__")
|
||||
|
||||
|
||||
def integrate_samples(volume_range, ts, density, channels):
|
||||
r"""
|
||||
Function integrating the model output.
|
||||
|
||||
Args:
|
||||
volume_range: Specifies the integral range [t0, t1]
|
||||
ts: timesteps
|
||||
density: torch.Tensor [batch_size, *shape, n_samples, 1]
|
||||
channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
|
||||
returns:
|
||||
channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density
|
||||
*transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume
|
||||
)
|
||||
"""
|
||||
|
||||
# 1. Calculate the weights
|
||||
_, _, dt = volume_range.partition(ts)
|
||||
ddensity = density * dt
|
||||
|
||||
mass = torch.cumsum(ddensity, dim=-2)
|
||||
transmittance = torch.exp(-mass[..., -1, :])
|
||||
|
||||
alphas = 1.0 - torch.exp(-ddensity)
|
||||
Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
|
||||
# This is the probability of light hitting and reflecting off of
|
||||
# something at depth [..., i, :].
|
||||
weights = alphas * Ts
|
||||
|
||||
# 2. Integrate channels
|
||||
channels = torch.sum(channels * weights, dim=-2)
|
||||
|
||||
return channels, weights, transmittance
|
||||
|
||||
|
||||
class VoidNeRFModel(nn.Module):
|
||||
"""
|
||||
Implements the default empty space model where all queries are rendered as background.
|
||||
"""
|
||||
|
||||
def __init__(self, background, channel_scale=255.0):
|
||||
super().__init__()
|
||||
background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale)
|
||||
|
||||
self.register_buffer("background", background)
|
||||
|
||||
def forward(self, position):
|
||||
background = self.background[None].to(position.device)
|
||||
|
||||
shape = position.shape[:-1]
|
||||
ones = [1] * (len(shape) - 1)
|
||||
n_channels = background.shape[-1]
|
||||
background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels])
|
||||
|
||||
return background
|
||||
|
||||
|
||||
@dataclass
|
||||
class VolumeRange:
|
||||
t0: torch.Tensor
|
||||
t1: torch.Tensor
|
||||
intersected: torch.Tensor
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.t0.shape == self.t1.shape == self.intersected.shape
|
||||
|
||||
def partition(self, ts):
|
||||
"""
|
||||
Partitions t0 and t1 into n_samples intervals.
|
||||
|
||||
Args:
|
||||
ts: [batch_size, *shape, n_samples, 1]
|
||||
|
||||
Return:
|
||||
|
||||
lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size,
|
||||
*shape, n_samples, 1]
|
||||
|
||||
where
|
||||
ts \\in [lower, upper] deltas = upper - lower
|
||||
"""
|
||||
|
||||
mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
|
||||
lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
|
||||
upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
|
||||
delta = upper - lower
|
||||
assert lower.shape == upper.shape == delta.shape == ts.shape
|
||||
return lower, upper, delta
|
||||
|
||||
|
||||
class BoundingBoxVolume(nn.Module):
|
||||
"""
|
||||
Axis-aligned bounding box defined by the two opposite corners.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bbox_min,
|
||||
bbox_max,
|
||||
min_dist: float = 0.0,
|
||||
min_t_range: float = 1e-3,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
bbox_min: the left/bottommost corner of the bounding box
|
||||
bbox_max: the other corner of the bounding box
|
||||
min_dist: all rays should start at least this distance away from the origin.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.min_dist = min_dist
|
||||
self.min_t_range = min_t_range
|
||||
|
||||
self.bbox_min = torch.tensor(bbox_min)
|
||||
self.bbox_max = torch.tensor(bbox_max)
|
||||
self.bbox = torch.stack([self.bbox_min, self.bbox_max])
|
||||
assert self.bbox.shape == (2, 3)
|
||||
assert min_dist >= 0.0
|
||||
assert min_t_range > 0.0
|
||||
|
||||
def intersect(
|
||||
self,
|
||||
origin: torch.Tensor,
|
||||
direction: torch.Tensor,
|
||||
t0_lower: Optional[torch.Tensor] = None,
|
||||
epsilon=1e-6,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
origin: [batch_size, *shape, 3]
|
||||
direction: [batch_size, *shape, 3]
|
||||
t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
|
||||
params: Optional meta parameters in case Volume is parametric
|
||||
epsilon: to stabilize calculations
|
||||
|
||||
Return:
|
||||
A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with
|
||||
the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to
|
||||
be on the boundary of the volume.
|
||||
"""
|
||||
|
||||
batch_size, *shape, _ = origin.shape
|
||||
ones = [1] * len(shape)
|
||||
bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device)
|
||||
|
||||
def _safe_divide(a, b, epsilon=1e-6):
|
||||
return a / torch.where(b < 0, b - epsilon, b + epsilon)
|
||||
|
||||
ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)
|
||||
|
||||
# Cases to think about:
|
||||
#
|
||||
# 1. t1 <= t0: the ray does not pass through the AABB.
|
||||
# 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
|
||||
# 3. t0 <= 0 <= t1: the ray starts from inside the BB
|
||||
# 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
|
||||
#
|
||||
# 1 and 4 are clearly handled from t0 < t1 below.
|
||||
# Making t0 at least min_dist (>= 0) takes care of 2 and 3.
|
||||
t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
|
||||
t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
|
||||
assert t0.shape == t1.shape == (batch_size, *shape, 1)
|
||||
if t0_lower is not None:
|
||||
assert t0.shape == t0_lower.shape
|
||||
t0 = torch.maximum(t0, t0_lower)
|
||||
|
||||
intersected = t0 + self.min_t_range < t1
|
||||
t0 = torch.where(intersected, t0, torch.zeros_like(t0))
|
||||
t1 = torch.where(intersected, t1, torch.ones_like(t1))
|
||||
|
||||
return VolumeRange(t0=t0, t1=t1, intersected=intersected)
|
||||
|
||||
|
||||
class StratifiedRaySampler(nn.Module):
|
||||
"""
|
||||
Instead of fixed intervals, a sample is drawn uniformly at random from each interval.
|
||||
"""
|
||||
|
||||
def __init__(self, depth_mode: str = "linear"):
|
||||
"""
|
||||
:param depth_mode: linear samples ts linearly in depth. harmonic ensures
|
||||
closer points are sampled more densely.
|
||||
"""
|
||||
self.depth_mode = depth_mode
|
||||
assert self.depth_mode in ("linear", "geometric", "harmonic")
|
||||
|
||||
def sample(
|
||||
self,
|
||||
t0: torch.Tensor,
|
||||
t1: torch.Tensor,
|
||||
n_samples: int,
|
||||
epsilon: float = 1e-3,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
t0: start time has shape [batch_size, *shape, 1]
|
||||
t1: finish time has shape [batch_size, *shape, 1]
|
||||
n_samples: number of ts to sample
|
||||
Return:
|
||||
sampled ts of shape [batch_size, *shape, n_samples, 1]
|
||||
"""
|
||||
ones = [1] * (len(t0.shape) - 1)
|
||||
ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)
|
||||
|
||||
if self.depth_mode == "linear":
|
||||
ts = t0 * (1.0 - ts) + t1 * ts
|
||||
elif self.depth_mode == "geometric":
|
||||
ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
|
||||
elif self.depth_mode == "harmonic":
|
||||
# The original NeRF recommends this interpolation scheme for
|
||||
# spherical scenes, but there could be some weird edge cases when
|
||||
# the observer crosses from the inner to outer volume.
|
||||
ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)
|
||||
|
||||
mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
|
||||
upper = torch.cat([mids, t1], dim=-1)
|
||||
lower = torch.cat([t0, mids], dim=-1)
|
||||
# yiyi notes: add a random seed here for testing, don't forget to remove
|
||||
torch.manual_seed(0)
|
||||
t_rand = torch.rand_like(ts)
|
||||
|
||||
ts = lower + (upper - lower) * t_rand
|
||||
return ts.unsqueeze(-1)
|
||||
|
||||
|
||||
class ImportanceRaySampler(nn.Module):
|
||||
"""
|
||||
Given the initial estimate of densities, this samples more from regions/bins expected to have objects.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
volume_range: VolumeRange,
|
||||
ts: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
blur_pool: bool = False,
|
||||
alpha: float = 1e-5,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
volume_range: the range in which a ray intersects the given volume.
|
||||
ts: earlier samples from the coarse rendering step
|
||||
weights: discretized version of density * transmittance
|
||||
blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
|
||||
alpha: small value to add to weights.
|
||||
"""
|
||||
self.volume_range = volume_range
|
||||
self.ts = ts.clone().detach()
|
||||
self.weights = weights.clone().detach()
|
||||
self.blur_pool = blur_pool
|
||||
self.alpha = alpha
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
t0: start time has shape [batch_size, *shape, 1]
|
||||
t1: finish time has shape [batch_size, *shape, 1]
|
||||
n_samples: number of ts to sample
|
||||
Return:
|
||||
sampled ts of shape [batch_size, *shape, n_samples, 1]
|
||||
"""
|
||||
lower, upper, _ = self.volume_range.partition(self.ts)
|
||||
|
||||
batch_size, *shape, n_coarse_samples, _ = self.ts.shape
|
||||
|
||||
weights = self.weights
|
||||
if self.blur_pool:
|
||||
padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
|
||||
maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
|
||||
weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
|
||||
weights = weights + self.alpha
|
||||
pmf = weights / weights.sum(dim=-2, keepdim=True)
|
||||
inds = sample_pmf(pmf, n_samples)
|
||||
assert inds.shape == (batch_size, *shape, n_samples, 1)
|
||||
assert (inds >= 0).all() and (inds < n_coarse_samples).all()
|
||||
|
||||
t_rand = torch.rand(inds.shape, device=inds.device)
|
||||
lower_ = torch.gather(lower, -2, inds)
|
||||
upper_ = torch.gather(upper, -2, inds)
|
||||
|
||||
ts = lower_ + (upper_ - lower_) * t_rand
|
||||
ts = torch.sort(ts, dim=-2).values
|
||||
return ts
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLPNeRFModelOutput(BaseOutput):
|
||||
density: torch.Tensor
|
||||
signed_distance: torch.Tensor
|
||||
channels: torch.Tensor
|
||||
ts: torch.Tensor
|
||||
|
||||
|
||||
class MLPNeRSTFModel(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
d_hidden: int = 256,
|
||||
n_output: int = 12,
|
||||
n_hidden_layers: int = 6,
|
||||
act_fn: str = "swish",
|
||||
insert_direction_at: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Instantiate the MLP
|
||||
|
||||
# Find out the dimension of encoded position and direction
|
||||
dummy = torch.eye(1, 3)
|
||||
d_posenc_pos = encode_position(position=dummy).shape[-1]
|
||||
d_posenc_dir = encode_direction(position=dummy).shape[-1]
|
||||
|
||||
mlp_widths = [d_hidden] * n_hidden_layers
|
||||
input_widths = [d_posenc_pos] + mlp_widths
|
||||
output_widths = mlp_widths + [n_output]
|
||||
|
||||
if insert_direction_at is not None:
|
||||
input_widths[insert_direction_at] += d_posenc_dir
|
||||
|
||||
self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)])
|
||||
|
||||
if act_fn == "swish":
|
||||
# self.activation = swish
|
||||
# yiyi testing:
|
||||
self.activation = lambda x: F.silu(x)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function {act_fn}")
|
||||
|
||||
self.sdf_activation = torch.tanh
|
||||
self.density_activation = torch.nn.functional.relu
|
||||
self.channel_activation = torch.sigmoid
|
||||
|
||||
def map_indices_to_keys(self, output):
|
||||
h_map = {
|
||||
"sdf": (0, 1),
|
||||
"density_coarse": (1, 2),
|
||||
"density_fine": (2, 3),
|
||||
"stf": (3, 6),
|
||||
"nerf_coarse": (6, 9),
|
||||
"nerf_fine": (9, 12),
|
||||
}
|
||||
|
||||
mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()}
|
||||
|
||||
return mapped_output
|
||||
|
||||
def forward(self, *, position, direction, ts, nerf_level="coarse"):
|
||||
h = encode_position(position)
|
||||
|
||||
h_preact = h
|
||||
h_directionless = None
|
||||
for i, layer in enumerate(self.mlp):
|
||||
if i == self.config.insert_direction_at: # 4 in the config
|
||||
h_directionless = h_preact
|
||||
h_direction = encode_direction(position, direction=direction)
|
||||
h = torch.cat([h, h_direction], dim=-1)
|
||||
|
||||
h = layer(h)
|
||||
|
||||
h_preact = h
|
||||
|
||||
if i < len(self.mlp) - 1:
|
||||
h = self.activation(h)
|
||||
|
||||
h_final = h
|
||||
if h_directionless is None:
|
||||
h_directionless = h_preact
|
||||
|
||||
activation = self.map_indices_to_keys(h_final)
|
||||
|
||||
if nerf_level == "coarse":
|
||||
h_density = activation["density_coarse"]
|
||||
h_channels = activation["nerf_coarse"]
|
||||
else:
|
||||
h_density = activation["density_fine"]
|
||||
h_channels = activation["nerf_fine"]
|
||||
|
||||
density = self.density_activation(h_density)
|
||||
signed_distance = self.sdf_activation(activation["sdf"])
|
||||
channels = self.channel_activation(h_channels)
|
||||
|
||||
# yiyi notes: I think signed_distance is not used
|
||||
return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts)
|
||||
|
||||
|
||||
class ChannelsProj(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vectors: int,
|
||||
channels: int,
|
||||
d_latent: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(d_latent, vectors * channels)
|
||||
self.norm = nn.LayerNorm(channels)
|
||||
self.d_latent = d_latent
|
||||
self.vectors = vectors
|
||||
self.channels = channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_bvd = x
|
||||
w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
|
||||
b_vc = self.proj.bias.view(1, self.vectors, self.channels)
|
||||
h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
|
||||
h = self.norm(h)
|
||||
|
||||
h = h + b_vc
|
||||
return h
|
||||
|
||||
|
||||
class ShapEParamsProjModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP).
|
||||
|
||||
For more details, see the original paper:
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
param_names: Tuple[str] = (
|
||||
"nerstf.mlp.0.weight",
|
||||
"nerstf.mlp.1.weight",
|
||||
"nerstf.mlp.2.weight",
|
||||
"nerstf.mlp.3.weight",
|
||||
),
|
||||
param_shapes: Tuple[Tuple[int]] = (
|
||||
(256, 93),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
),
|
||||
d_latent: int = 1024,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# check inputs
|
||||
if len(param_names) != len(param_shapes):
|
||||
raise ValueError("Must provide same number of `param_names` as `param_shapes`")
|
||||
self.projections = nn.ModuleDict({})
|
||||
for k, (vectors, channels) in zip(param_names, param_shapes):
|
||||
self.projections[_sanitize_name(k)] = ChannelsProj(
|
||||
vectors=vectors,
|
||||
channels=channels,
|
||||
d_latent=d_latent,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
out = {}
|
||||
start = 0
|
||||
for k, shape in zip(self.config.param_names, self.config.param_shapes):
|
||||
vectors, _ = shape
|
||||
end = start + vectors
|
||||
x_bvd = x[:, start:end]
|
||||
out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
|
||||
start = end
|
||||
return out
|
||||
|
||||
|
||||
class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
param_names: Tuple[str] = (
|
||||
"nerstf.mlp.0.weight",
|
||||
"nerstf.mlp.1.weight",
|
||||
"nerstf.mlp.2.weight",
|
||||
"nerstf.mlp.3.weight",
|
||||
),
|
||||
param_shapes: Tuple[Tuple[int]] = (
|
||||
(256, 93),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
),
|
||||
d_latent: int = 1024,
|
||||
d_hidden: int = 256,
|
||||
n_output: int = 12,
|
||||
n_hidden_layers: int = 6,
|
||||
act_fn: str = "swish",
|
||||
insert_direction_at: int = 4,
|
||||
background: Tuple[float] = (
|
||||
255.0,
|
||||
255.0,
|
||||
255.0,
|
||||
),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.params_proj = ShapEParamsProjModel(
|
||||
param_names=param_names,
|
||||
param_shapes=param_shapes,
|
||||
d_latent=d_latent,
|
||||
)
|
||||
self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
|
||||
self.void = VoidNeRFModel(background=background, channel_scale=255.0)
|
||||
self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
|
||||
|
||||
@torch.no_grad()
|
||||
def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
|
||||
"""
|
||||
Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below
|
||||
with some abuse of notations)
|
||||
|
||||
C(r) := sum(
|
||||
transmittance(t[i]) * integrate(
|
||||
lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]],
|
||||
) for i in range(len(parts))
|
||||
) + transmittance(t[-1]) * void_model(t[-1]).channels
|
||||
|
||||
where
|
||||
|
||||
1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through
|
||||
the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are
|
||||
obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t
|
||||
where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the
|
||||
shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and
|
||||
transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1],
|
||||
math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
|
||||
|
||||
args:
|
||||
rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples:
|
||||
number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including
|
||||
|
||||
:return: A tuple of
|
||||
- `channels`
|
||||
- A importance samplers for additional fine-grained rendering
|
||||
- raw model output
|
||||
"""
|
||||
origin, direction = rays[..., 0, :], rays[..., 1, :]
|
||||
|
||||
# Integrate over [t[i], t[i + 1]]
|
||||
|
||||
# 1 Intersect the rays with the current volume and sample ts to integrate along.
|
||||
vrange = self.volume.intersect(origin, direction, t0_lower=None)
|
||||
ts = sampler.sample(vrange.t0, vrange.t1, n_samples)
|
||||
ts = ts.to(rays.dtype)
|
||||
|
||||
if prev_model_out is not None:
|
||||
# Append the previous ts now before fprop because previous
|
||||
# rendering used a different model and we can't reuse the output.
|
||||
ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values
|
||||
|
||||
batch_size, *_shape, _t0_dim = vrange.t0.shape
|
||||
_, *ts_shape, _ts_dim = ts.shape
|
||||
|
||||
# 2. Get the points along the ray and query the model
|
||||
directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
|
||||
positions = origin.unsqueeze(-2) + ts * directions
|
||||
|
||||
directions = directions.to(self.mlp.dtype)
|
||||
positions = positions.to(self.mlp.dtype)
|
||||
|
||||
optional_directions = directions if render_with_direction else None
|
||||
|
||||
model_out = self.mlp(
|
||||
position=positions,
|
||||
direction=optional_directions,
|
||||
ts=ts,
|
||||
nerf_level="coarse" if prev_model_out is None else "fine",
|
||||
)
|
||||
|
||||
# 3. Integrate the model results
|
||||
channels, weights, transmittance = integrate_samples(
|
||||
vrange, model_out.ts, model_out.density, model_out.channels
|
||||
)
|
||||
|
||||
# 4. Clean up results that do not intersect with the volume.
|
||||
transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance))
|
||||
channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels))
|
||||
# 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
|
||||
channels = channels + transmittance * self.void(origin)
|
||||
|
||||
weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights)
|
||||
|
||||
return channels, weighted_sampler, model_out
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
latents,
|
||||
device,
|
||||
size: int = 64,
|
||||
ray_batch_size: int = 4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
):
|
||||
# project the the paramters from the generated latents
|
||||
projected_params = self.params_proj(latents)
|
||||
|
||||
# update the mlp layers of the renderer
|
||||
for name, param in self.mlp.state_dict().items():
|
||||
if f"nerstf.{name}" in projected_params.keys():
|
||||
param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))
|
||||
|
||||
# create cameras object
|
||||
camera = create_pan_cameras(size)
|
||||
rays = camera.camera_rays
|
||||
rays = rays.to(device)
|
||||
n_batches = rays.shape[1] // ray_batch_size
|
||||
|
||||
coarse_sampler = StratifiedRaySampler()
|
||||
|
||||
images = []
|
||||
|
||||
for idx in range(n_batches):
|
||||
rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]
|
||||
|
||||
# render rays with coarse, stratified samples.
|
||||
_, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
|
||||
# Then, render with additional importance-weighted ray samples.
|
||||
channels, _, _ = self.render_rays(
|
||||
rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
|
||||
)
|
||||
|
||||
images.append(channels)
|
||||
|
||||
images = torch.cat(images, dim=1)
|
||||
images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)
|
||||
|
||||
return images
|
||||
@@ -24,6 +24,7 @@ from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
BertTokenizerFast,
|
||||
CLIPImageProcessor,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
@@ -48,7 +49,7 @@ from ...schedulers import (
|
||||
PNDMScheduler,
|
||||
UnCLIPScheduler,
|
||||
)
|
||||
from ...utils import is_omegaconf_available, is_safetensors_available, logging
|
||||
from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
|
||||
from ...utils.import_utils import BACKENDS_MAPPING
|
||||
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from ..paint_by_example import PaintByExampleImageEncoder
|
||||
@@ -57,6 +58,10 @@ from .safety_checker import StableDiffusionSafetyChecker
|
||||
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -233,7 +238,10 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
if controlnet:
|
||||
unet_params = original_config.model.params.control_stage_config.params
|
||||
else:
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
else:
|
||||
unet_params = original_config.model.params.network_config.params
|
||||
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
|
||||
@@ -253,6 +261,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
if unet_params.transformer_depth is not None:
|
||||
transformer_layers_per_block = (
|
||||
unet_params.transformer_depth
|
||||
if isinstance(unet_params.transformer_depth, int)
|
||||
else list(unet_params.transformer_depth)
|
||||
)
|
||||
else:
|
||||
transformer_layers_per_block = 1
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||
|
||||
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
||||
@@ -262,14 +279,28 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
if use_linear_projection:
|
||||
# stable diffusion 2-base-512 and 2-768
|
||||
if head_dim is None:
|
||||
head_dim = [5, 10, 20, 20]
|
||||
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
|
||||
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
|
||||
|
||||
class_embed_type = None
|
||||
addition_embed_type = None
|
||||
addition_time_embed_dim = None
|
||||
projection_class_embeddings_input_dim = None
|
||||
context_dim = None
|
||||
|
||||
if unet_params.context_dim is not None:
|
||||
context_dim = (
|
||||
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
|
||||
)
|
||||
|
||||
if "num_classes" in unet_params:
|
||||
if unet_params.num_classes == "sequential":
|
||||
class_embed_type = "projection"
|
||||
if context_dim in [2048, 1280]:
|
||||
# SDXL
|
||||
addition_embed_type = "text_time"
|
||||
addition_time_embed_dim = 256
|
||||
else:
|
||||
class_embed_type = "projection"
|
||||
assert "adm_in_channels" in unet_params
|
||||
projection_class_embeddings_input_dim = unet_params.adm_in_channels
|
||||
else:
|
||||
@@ -281,11 +312,14 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"layers_per_block": unet_params.num_res_blocks,
|
||||
"cross_attention_dim": unet_params.context_dim,
|
||||
"cross_attention_dim": context_dim,
|
||||
"attention_head_dim": head_dim,
|
||||
"use_linear_projection": use_linear_projection,
|
||||
"class_embed_type": class_embed_type,
|
||||
"addition_embed_type": addition_embed_type,
|
||||
"addition_time_embed_dim": addition_time_embed_dim,
|
||||
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
||||
"transformer_layers_per_block": transformer_layers_per_block,
|
||||
}
|
||||
|
||||
if controlnet:
|
||||
@@ -362,8 +396,8 @@ def convert_ldm_unet_checkpoint(
|
||||
|
||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
||||
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
print(
|
||||
logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
logger.warning(
|
||||
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
||||
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
||||
)
|
||||
@@ -373,7 +407,7 @@ def convert_ldm_unet_checkpoint(
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||
else:
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
print(
|
||||
logger.warning(
|
||||
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
||||
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
||||
)
|
||||
@@ -400,6 +434,12 @@ def convert_ldm_unet_checkpoint(
|
||||
else:
|
||||
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
||||
|
||||
if config["addition_embed_type"] == "text_time":
|
||||
new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
||||
new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
||||
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
||||
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
||||
|
||||
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
||||
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
||||
|
||||
@@ -735,30 +775,36 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
|
||||
text_model = (
|
||||
CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
|
||||
if text_encoder is None
|
||||
else text_encoder
|
||||
)
|
||||
if text_encoder is None:
|
||||
config_name = "openai/clip-vit-large-patch14"
|
||||
config = CLIPTextConfig.from_pretrained(config_name)
|
||||
|
||||
with init_empty_weights():
|
||||
text_model = CLIPTextModel(config)
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
for key in keys:
|
||||
for prefix in remove_prefixes:
|
||||
if key.startswith(prefix):
|
||||
text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
|
||||
|
||||
for param_name, param in text_model_dict.items():
|
||||
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
|
||||
|
||||
return text_model
|
||||
|
||||
|
||||
textenc_conversion_lst = [
|
||||
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
||||
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
||||
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
||||
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
||||
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
||||
("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
||||
("ln_final.weight", "text_model.final_layer_norm.weight"),
|
||||
("ln_final.bias", "text_model.final_layer_norm.bias"),
|
||||
("text_projection", "text_projection.weight"),
|
||||
]
|
||||
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
||||
|
||||
@@ -845,27 +891,48 @@ def convert_paint_by_example_checkpoint(checkpoint):
|
||||
return model
|
||||
|
||||
|
||||
def convert_open_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
||||
def convert_open_clip_checkpoint(
|
||||
checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
|
||||
):
|
||||
# text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
||||
# text_model = CLIPTextModelWithProjection.from_pretrained(
|
||||
# "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
|
||||
# )
|
||||
config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
|
||||
|
||||
with init_empty_weights():
|
||||
text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
keys_to_ignore = []
|
||||
if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
|
||||
# make sure to remove all keys > 22
|
||||
keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
|
||||
keys_to_ignore += ["cond_stage_model.model.text_projection"]
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
if "cond_stage_model.model.text_projection" in checkpoint:
|
||||
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
||||
if prefix + "text_projection" in checkpoint:
|
||||
d_model = int(checkpoint[prefix + "text_projection"].shape[0])
|
||||
else:
|
||||
d_model = 1024
|
||||
|
||||
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
||||
|
||||
for key in keys:
|
||||
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
||||
if key in keys_to_ignore:
|
||||
continue
|
||||
if key in textenc_conversion_map:
|
||||
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
||||
if key.startswith("cond_stage_model.model.transformer."):
|
||||
new_key = key[len("cond_stage_model.model.transformer.") :]
|
||||
if key[len(prefix) :] in textenc_conversion_map:
|
||||
if key.endswith("text_projection"):
|
||||
value = checkpoint[key].T
|
||||
else:
|
||||
value = checkpoint[key]
|
||||
|
||||
text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value
|
||||
|
||||
if key.startswith(prefix + "transformer."):
|
||||
new_key = key[len(prefix + "transformer.") :]
|
||||
if new_key.endswith(".in_proj_weight"):
|
||||
new_key = new_key[: -len(".in_proj_weight")]
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
@@ -883,7 +950,8 @@ def convert_open_clip_checkpoint(checkpoint):
|
||||
|
||||
text_model_dict[new_key] = checkpoint[key]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
for param_name, param in text_model_dict.items():
|
||||
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
|
||||
|
||||
return text_model
|
||||
|
||||
@@ -1013,7 +1081,7 @@ def convert_controlnet_checkpoint(
|
||||
def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path: str,
|
||||
original_config_file: str = None,
|
||||
image_size: int = 512,
|
||||
image_size: Optional[int] = None,
|
||||
prediction_type: str = None,
|
||||
model_type: str = None,
|
||||
extract_ema: bool = False,
|
||||
@@ -1029,6 +1097,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
load_safety_checker: bool = True,
|
||||
pipeline_class: DiffusionPipeline = None,
|
||||
local_files_only=False,
|
||||
vae_path=None,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
) -> DiffusionPipeline:
|
||||
@@ -1090,12 +1159,15 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||
"""
|
||||
|
||||
# import pipelines here to avoid circular import error when using from_ckpt method
|
||||
# import pipelines here to avoid circular import error when using from_single_file method
|
||||
from diffusers import (
|
||||
LDMTextToImagePipeline,
|
||||
PaintByExamplePipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
@@ -1115,12 +1187,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
if not is_safetensors_available():
|
||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
||||
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file as safe_load
|
||||
|
||||
checkpoint = {}
|
||||
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
checkpoint[key] = f.get_tensor(key)
|
||||
checkpoint = safe_load(checkpoint_path, device="cpu")
|
||||
else:
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -1132,7 +1201,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
print("global_step key not found in model")
|
||||
logger.debug("global_step key not found in model")
|
||||
global_step = None
|
||||
|
||||
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
|
||||
@@ -1141,24 +1210,53 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
if original_config_file is None:
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
|
||||
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
|
||||
|
||||
# model_type = "v1"
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
|
||||
# model_type = "v2"
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
elif key_name_sd_xl_base in checkpoint:
|
||||
# only base xl has two text embedders
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
|
||||
elif key_name_sd_xl_refiner in checkpoint:
|
||||
# only refiner xl has embedder and one text embedders
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
|
||||
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
# Convert the text model.
|
||||
if (
|
||||
model_type is None
|
||||
and "cond_stage_config" in original_config.model.params
|
||||
and original_config.model.params.cond_stage_config is not None
|
||||
):
|
||||
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
|
||||
elif model_type is None and original_config.model.params.network_config is not None:
|
||||
if original_config.model.params.network_config.params.context_dim == 2048:
|
||||
model_type = "SDXL"
|
||||
else:
|
||||
model_type = "SDXL-Refiner"
|
||||
if image_size is None:
|
||||
image_size = 1024
|
||||
|
||||
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
|
||||
num_in_channels = 9
|
||||
elif num_in_channels is None:
|
||||
num_in_channels = 4
|
||||
|
||||
if "unet_config" in original_config.model.params:
|
||||
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
|
||||
if (
|
||||
@@ -1187,20 +1285,37 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
||||
)
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=beta_start,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
steps_offset=1,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
if model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
scheduler_dict = {
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"beta_end": 0.012,
|
||||
"interpolation_type": "linear",
|
||||
"num_train_timesteps": num_train_timesteps,
|
||||
"prediction_type": "epsilon",
|
||||
"sample_max_value": 1.0,
|
||||
"set_alpha_to_one": False,
|
||||
"skip_prk_steps": True,
|
||||
"steps_offset": 1,
|
||||
"timestep_spacing": "leading",
|
||||
}
|
||||
scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
|
||||
scheduler_type = "euler"
|
||||
else:
|
||||
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
|
||||
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
|
||||
scheduler = DDIMScheduler(
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=beta_start,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
steps_offset=1,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
# make sure scheduler works correctly with DDIM
|
||||
scheduler.register_to_config(clip_sample=False)
|
||||
|
||||
@@ -1226,28 +1341,45 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
with init_empty_weights():
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
||||
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
|
||||
)
|
||||
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
for param_name, param in converted_unet_checkpoint.items():
|
||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
if vae_path is None:
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
if (
|
||||
"model" in original_config
|
||||
and "params" in original_config.model
|
||||
and "scale_factor" in original_config.model.params
|
||||
):
|
||||
vae_scaling_factor = original_config.model.params.scale_factor
|
||||
else:
|
||||
vae_scaling_factor = 0.18215 # default SD scaling factor
|
||||
|
||||
# Convert the text model.
|
||||
if model_type is None:
|
||||
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
|
||||
vae_config["scaling_factor"] = vae_scaling_factor
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
for param_name, param in converted_vae_checkpoint.items():
|
||||
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
|
||||
else:
|
||||
vae = AutoencoderKL.from_pretrained(vae_path)
|
||||
|
||||
if model_type == "FrozenOpenCLIPEmbedder":
|
||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||
config_name = "stabilityai/stable-diffusion-2"
|
||||
config_kwargs = {"subfolder": "text_encoder"}
|
||||
|
||||
text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
|
||||
|
||||
if stable_unclip is None:
|
||||
@@ -1375,6 +1507,50 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
if model_type == "SDXL":
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
|
||||
|
||||
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
config_kwargs = {"projection_dim": 1280}
|
||||
text_encoder_2 = convert_open_clip_checkpoint(
|
||||
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
else:
|
||||
tokenizer = None
|
||||
text_encoder = None
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
|
||||
|
||||
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
config_kwargs = {"projection_dim": 1280}
|
||||
text_encoder_2 = convert_open_clip_checkpoint(
|
||||
checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLImg2ImgPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
requires_aesthetics_score=True,
|
||||
force_zeros_for_empty_prompt=False,
|
||||
)
|
||||
else:
|
||||
text_config = create_ldm_bert_config(original_config)
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -69,7 +69,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
|
||||
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
@@ -79,7 +79,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -98,7 +98,9 @@ def preprocess(image):
|
||||
return image
|
||||
|
||||
|
||||
class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
|
||||
class StableDiffusionImg2ImgPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-guided image to image generation using Stable Diffusion.
|
||||
|
||||
@@ -108,7 +110,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
|
||||
@@ -153,7 +153,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionInpaintPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-guided image inpainting using Stable Diffusion.
|
||||
|
||||
|
||||
+3
-3
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -85,7 +85,7 @@ def preprocess_mask(mask, batch_size, scale_factor=8):
|
||||
|
||||
|
||||
class StableDiffusionInpaintPipelineLegacy(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
|
||||
@@ -96,7 +96,7 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessorLDM3D
|
||||
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -74,7 +74,9 @@ class LDM3DPipelineOutput(BaseOutput):
|
||||
nsfw_content_detected: Optional[List[bool]]
|
||||
|
||||
|
||||
class StableDiffusionLDM3DPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
|
||||
class StableDiffusionLDM3DPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image and 3d generation using LDM3D. LDM3D: Latent Diffusion Model for 3D:
|
||||
https://arxiv.org/abs/2305.10853
|
||||
@@ -85,7 +87,7 @@ class StableDiffusionLDM3DPipeline(DiffusionPipeline, TextualInversionLoaderMixi
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -59,7 +59,9 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionParadigmsPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
|
||||
class StableDiffusionParadigmsPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Parallelized version of StableDiffusionPipeline, based on the paper https://arxiv.org/abs/2305.16317 This pipeline
|
||||
parallelizes the denoising steps to generate a single image faster (more akin to model parallelism).
|
||||
@@ -72,7 +74,7 @@ class StableDiffusionParadigmsPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
@@ -24,7 +24,12 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
@@ -743,14 +748,19 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
]
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if not use_torch_2_0_or_xformers:
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(latents.dtype)
|
||||
self.vae.decoder.conv_in.to(latents.dtype)
|
||||
self.vae.decoder.mid_block.to(latents.dtype)
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
||||
from ...utils import BaseOutput, is_invisible_watermark_available, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableDiffusionXLPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Stable Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
|
||||
if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():
|
||||
from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
|
||||
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
|
||||
@@ -0,0 +1,823 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionXLPipelineOutput
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
>>> pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
||||
steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
|
||||
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
model_sequence = (
|
||||
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
||||
)
|
||||
model_sequence.extend([self.unet, self.vae])
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in model_sequence:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
||||
text_encoders = (
|
||||
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
||||
)
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
||||
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
uncond_tokens: List[str]
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
negative_prompt_embeds = text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
bs_embed = pooled_prompt_embeds.shape[0]
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||
)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.7):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
TODO
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
TODO
|
||||
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
TODO
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
|
||||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(latents.dtype)
|
||||
self.vae.decoder.conv_in.to(latents.dtype)
|
||||
self.vae.decoder.mid_block.to(latents.dtype)
|
||||
else:
|
||||
latents = latents.float()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
image = self.watermark.apply_watermark(image)
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
@@ -0,0 +1,896 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionXLPipelineOutput
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusionXLImg2ImgPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
>>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
|
||||
|
||||
>>> init_image = load_image(url).convert("RGB")
|
||||
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
||||
>>> image = pipe(prompt, image=init_image).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
requires_aesthetics_score: bool = False,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
||||
steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
|
||||
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.enable_model_cpu_offload
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
model_sequence = (
|
||||
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
||||
)
|
||||
model_sequence.extend([self.unet, self.vae])
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in model_sequence:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
||||
text_encoders = (
|
||||
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
||||
)
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
||||
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
uncond_tokens: List[str]
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
negative_prompt_embeds = text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
bs_embed = pooled_prompt_embeds.shape[0]
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
if image.shape[1] == 4:
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
image = image.float()
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
elif isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
self.vae.to(dtype)
|
||||
init_latents = init_latents.to(dtype)
|
||||
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
shape = init_latents.shape
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
latents = init_latents
|
||||
|
||||
return latents
|
||||
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
|
||||
else:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
if (
|
||||
expected_add_embed_dim > passed_add_embed_dim
|
||||
and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
|
||||
):
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
|
||||
)
|
||||
elif (
|
||||
expected_add_embed_dim < passed_add_embed_dim
|
||||
and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
|
||||
):
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
|
||||
)
|
||||
elif expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||
)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
|
||||
|
||||
return add_time_ids, add_neg_time_ids
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
strength: float = 0.3,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Tuple[int, int] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Tuple[int, int] = None,
|
||||
aesthetic_score: float = 6.0,
|
||||
negative_aesthetic_score: float = 2.5,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
|
||||
The image(s) to modify with the pipeline.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.7):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
TODO
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
TODO
|
||||
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
TODO
|
||||
aesthetic_score (`float`, *optional*, defaults to 6.0):
|
||||
TODO
|
||||
negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
|
||||
TDOO
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
|
||||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Preprocess image
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
latents = self.prepare_latents(
|
||||
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
|
||||
)
|
||||
# 7. Prepare extra step kwargs.
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
height, width = latents.shape[-2:]
|
||||
height = height * self.vae_scale_factor
|
||||
width = width * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 8. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(latents.dtype)
|
||||
self.vae.decoder.conv_in.to(latents.dtype)
|
||||
self.vae.decoder.mid_block.to(latents.dtype)
|
||||
else:
|
||||
latents = latents.float()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
image = self.watermark.apply_watermark(image)
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
@@ -0,0 +1,31 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from imwatermark import WatermarkEncoder
|
||||
|
||||
|
||||
# Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66
|
||||
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
||||
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
||||
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
||||
|
||||
|
||||
class StableDiffusionXLWatermarker:
|
||||
def __init__(self):
|
||||
self.watermark = WATERMARK_BITS
|
||||
self.encoder = WatermarkEncoder()
|
||||
|
||||
self.encoder.set_watermark("bits", self.watermark)
|
||||
|
||||
def apply_watermark(self, images: torch.FloatTensor):
|
||||
# can't encode images that are smaller than 256
|
||||
if images.shape[-1] < 256:
|
||||
return images
|
||||
|
||||
images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
images = [self.encoder.encode(image, "dwtDct") for image in images]
|
||||
|
||||
images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
|
||||
|
||||
images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0)
|
||||
return images
|
||||
@@ -648,7 +648,8 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
|
||||
+2
-1
@@ -723,7 +723,8 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
|
||||
@@ -18,6 +18,9 @@ from ...models.attention_processor import (
|
||||
from ...models.dual_transformer_2d import DualTransformer2DModel
|
||||
from ...models.embeddings import (
|
||||
GaussianFourierProjection,
|
||||
ImageHintTimeEmbedding,
|
||||
ImageProjection,
|
||||
ImageTimeEmbedding,
|
||||
TextImageProjection,
|
||||
TextImageTimeEmbedding,
|
||||
TextTimeEmbedding,
|
||||
@@ -189,7 +192,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
encoder_hid_dim (`int`, *optional*, defaults to `None`):
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
@@ -206,6 +213,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
||||
Dimension for the timestep embeddings.
|
||||
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
@@ -266,6 +275,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
@@ -274,6 +284,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
@@ -296,6 +307,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
if num_attention_heads is not None:
|
||||
raise ValueError(
|
||||
"At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
|
||||
" because of a naming issue as described in"
|
||||
" https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing"
|
||||
" `num_attention_heads` will only be supported in diffusers v0.19."
|
||||
)
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
@@ -401,7 +420,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type == "image_proj":
|
||||
# Kandinsky 2.2
|
||||
self.encoder_hid_proj = ImageProjection(
|
||||
image_embed_dim=encoder_hid_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
@@ -454,6 +478,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
elif addition_embed_type == "image":
|
||||
# Kandinsky 2.2
|
||||
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
||||
elif addition_embed_type == "image_hint":
|
||||
# Kandinsky 2.2 ControlNet
|
||||
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
|
||||
@@ -486,6 +519,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
if isinstance(layers_per_block, int):
|
||||
layers_per_block = [layers_per_block] * len(down_block_types)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
if class_embeddings_concat:
|
||||
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
||||
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
||||
@@ -504,6 +540,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block[i],
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
@@ -529,6 +566,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
# mid
|
||||
if mid_block_type == "UNetMidBlockFlatCrossAttn":
|
||||
self.mid_block = UNetMidBlockFlatCrossAttn(
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
@@ -570,6 +608,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_layers_per_block = list(reversed(layers_per_block))
|
||||
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
||||
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
@@ -590,6 +629,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=reversed_layers_per_block[i] + 1,
|
||||
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
@@ -796,6 +836,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
@@ -866,6 +909,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
@@ -887,9 +931,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
emb = emb + aug_emb
|
||||
elif self.config.addition_embed_type == "text_image":
|
||||
# Kadinsky 2.1 - style
|
||||
# Kandinsky 2.1 - style
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
|
||||
@@ -898,9 +941,48 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
image_embs = added_cond_kwargs.get("image_embeds")
|
||||
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
||||
|
||||
aug_emb = self.add_embedding(text_embs, image_embs)
|
||||
emb = emb + aug_emb
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
|
||||
" the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
|
||||
" the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
elif self.config.addition_embed_type == "image":
|
||||
# Kandinsky 2.2 - style
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
|
||||
" keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
image_embs = added_cond_kwargs.get("image_embeds")
|
||||
aug_emb = self.add_embedding(image_embs)
|
||||
elif self.config.addition_embed_type == "image_hint":
|
||||
# Kandinsky 2.2 - style
|
||||
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
|
||||
" the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
image_embs = added_cond_kwargs.get("image_embeds")
|
||||
hint = added_cond_kwargs.get("hint")
|
||||
aug_emb, hint = self.add_embedding(image_embs, hint)
|
||||
sample = torch.cat([sample, hint], dim=1)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
if self.time_embed_act is not None:
|
||||
emb = self.time_embed_act(emb)
|
||||
@@ -917,7 +999,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
image_embeds = added_cond_kwargs.get("image_embeds")
|
||||
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
||||
|
||||
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
||||
# Kandinsky 2.2 - style
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
|
||||
" the keyword argument `image_embeds` to be passed in `added_conditions`"
|
||||
)
|
||||
image_embeds = added_cond_kwargs.get("image_embeds")
|
||||
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
@@ -1212,6 +1302,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -1256,7 +1347,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
num_layers=transformer_layers_per_block,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -1446,6 +1537,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -1491,7 +1583,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
num_layers=transformer_layers_per_block,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -1592,6 +1684,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -1634,7 +1727,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
num_layers=transformer_layers_per_block,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
|
||||
@@ -28,6 +28,7 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .scheduling_consistency_models import CMStochasticIterativeScheduler
|
||||
from .scheduling_ddim import DDIMScheduler
|
||||
from .scheduling_ddim_inverse import DDIMInverseScheduler
|
||||
from .scheduling_ddim_parallel import DDIMParallelScheduler
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging, randn_tensor
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class CMStochasticIterativeSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
|
||||
|
||||
class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in the
|
||||
paper [1].
|
||||
|
||||
[1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
|
||||
https://arxiv.org/pdf/2303.01469 [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based
|
||||
Generative Models." https://arxiv.org/abs/2206.00364
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
sigma_min (`float`):
|
||||
Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the original implementation.
|
||||
sigma_max (`float`):
|
||||
Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the original implementation.
|
||||
sigma_data (`float`):
|
||||
The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the
|
||||
original implementation, which is also the original value suggested in the EDM paper.
|
||||
s_noise (`float`):
|
||||
The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
|
||||
1.011]. This was set to 1.0 in the original implementation.
|
||||
rho (`float`):
|
||||
The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This was
|
||||
set to 7.0 in the original implementation, which is also the original value suggested in the EDM paper.
|
||||
clip_denoised (`bool`):
|
||||
Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`.
|
||||
timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*):
|
||||
Optionally, an explicit timestep schedule can be specified. The timesteps are expected to be in increasing
|
||||
order.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 40,
|
||||
sigma_min: float = 0.002,
|
||||
sigma_max: float = 80.0,
|
||||
sigma_data: float = 0.5,
|
||||
s_noise: float = 1.0,
|
||||
rho: float = 7.0,
|
||||
clip_denoised: bool = True,
|
||||
):
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = sigma_max
|
||||
|
||||
ramp = np.linspace(0, 1, num_train_timesteps)
|
||||
sigmas = self._convert_to_karras(ramp)
|
||||
timesteps = self.sigma_to_t(sigmas)
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.custom_timesteps = False
|
||||
self.is_scale_input_called = False
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
return indices.item()
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`, following the EDM model.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): input sample
|
||||
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
|
||||
Returns:
|
||||
`torch.FloatTensor`: scaled input sample
|
||||
"""
|
||||
# Get sigma corresponding to timestep
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_idx = self.index_for_timestep(timestep)
|
||||
sigma = self.sigmas[step_idx]
|
||||
|
||||
sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
|
||||
def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
|
||||
"""
|
||||
Gets scaled timesteps from the Karras sigmas, for input to the consistency model.
|
||||
|
||||
Args:
|
||||
sigmas (`float` or `np.ndarray`): single Karras sigma or array of Karras sigmas
|
||||
Returns:
|
||||
`float` or `np.ndarray`: scaled input timestep or scaled input timestep array
|
||||
"""
|
||||
if not isinstance(sigmas, np.ndarray):
|
||||
sigmas = np.array(sigmas, dtype=np.float64)
|
||||
|
||||
timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44)
|
||||
|
||||
return timesteps
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, optional):
|
||||
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, optional):
|
||||
custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
"""
|
||||
if num_inference_steps is None and timesteps is None:
|
||||
raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")
|
||||
|
||||
if num_inference_steps is not None and timesteps is not None:
|
||||
raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
|
||||
|
||||
# Follow DDPMScheduler custom timesteps logic
|
||||
if timesteps is not None:
|
||||
for i in range(1, len(timesteps)):
|
||||
if timesteps[i] >= timesteps[i - 1]:
|
||||
raise ValueError("`timesteps` must be in descending order.")
|
||||
|
||||
if timesteps[0] >= self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`timesteps` must start before `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps}."
|
||||
)
|
||||
|
||||
timesteps = np.array(timesteps, dtype=np.int64)
|
||||
self.custom_timesteps = True
|
||||
else:
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
||||
self.custom_timesteps = False
|
||||
|
||||
# Map timesteps to Karras sigmas directly for multistep sampling
|
||||
# See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
|
||||
num_train_timesteps = self.config.num_train_timesteps
|
||||
ramp = timesteps[::-1].copy()
|
||||
ramp = ramp / (num_train_timesteps - 1)
|
||||
sigmas = self._convert_to_karras(ramp)
|
||||
timesteps = self.sigma_to_t(sigmas)
|
||||
|
||||
sigmas = np.concatenate([sigmas, [self.sigma_min]]).astype(np.float32)
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
# mps does not support float64
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
|
||||
else:
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
# Modified _convert_to_karras implementation that takes in ramp as argument
|
||||
def _convert_to_karras(self, ramp):
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = self.config.sigma_min
|
||||
sigma_max: float = self.config.sigma_max
|
||||
|
||||
rho = self.config.rho
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
sigma_data = self.config.sigma_data
|
||||
|
||||
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
|
||||
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
def get_scalings_for_boundary_condition(self, sigma):
|
||||
"""
|
||||
Gets the scalings used in the consistency model parameterization, following Appendix C of the original paper.
|
||||
This enforces the consistency model boundary condition.
|
||||
|
||||
Note that `epsilon` in the equations for c_skip and c_out is set to sigma_min.
|
||||
|
||||
Args:
|
||||
sigma (`torch.FloatTensor`):
|
||||
The current sigma in the Karras sigma schedule.
|
||||
Returns:
|
||||
`tuple`:
|
||||
A two-element tuple where c_skip (which weights the current sample) is the first element and c_out
|
||||
(which weights the consistency model output) is the second element.
|
||||
"""
|
||||
sigma_min = self.config.sigma_min
|
||||
sigma_data = self.config.sigma_data
|
||||
|
||||
c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
|
||||
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`float`): current timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
generator (`torch.Generator`, *optional*): Random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
f" `{self.__class__}.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
logger.warning(
|
||||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
sigma_min = self.config.sigma_min
|
||||
sigma_max = self.config.sigma_max
|
||||
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
# sigma_next corresponds to next_t in original implementation
|
||||
sigma = self.sigmas[step_index]
|
||||
if step_index + 1 < self.config.num_train_timesteps:
|
||||
sigma_next = self.sigmas[step_index + 1]
|
||||
else:
|
||||
# Set sigma_next to sigma_min
|
||||
sigma_next = self.sigmas[-1]
|
||||
|
||||
# Get scalings for boundary conditions
|
||||
c_skip, c_out = self.get_scalings_for_boundary_condition(sigma)
|
||||
|
||||
# 1. Denoise model output using boundary conditions
|
||||
denoised = c_out * model_output + c_skip * sample
|
||||
if self.config.clip_denoised:
|
||||
denoised = denoised.clamp(-1, 1)
|
||||
|
||||
# 2. Sample z ~ N(0, s_noise^2 * I)
|
||||
# Noise is not used for onestep sampling.
|
||||
if len(self.timesteps) > 1:
|
||||
noise = randn_tensor(
|
||||
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
|
||||
)
|
||||
else:
|
||||
noise = torch.zeros_like(model_output)
|
||||
z = noise * self.config.s_noise
|
||||
|
||||
sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max)
|
||||
|
||||
# 3. Return noisy sample
|
||||
# tau = sigma_hat, eps = sigma_min
|
||||
prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user