Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b13dfac9dd | |||
| 451631be51 | |||
| 71d84a9ce1 | |||
| cfd84dfc14 | |||
| 0592773d90 | |||
| c1bad6e488 | |||
| b62104c737 | |||
| 697594f635 | |||
| 83d0aba6c0 | |||
| 47b3346422 | |||
| 07f1fbb18e | |||
| 2551b73670 | |||
| 930c8fdcb7 | |||
| 6b1abba18d | |||
| 470f51cd26 | |||
| b7e35dc782 | |||
| c77ac246c1 | |||
| ed2a3584ab | |||
| 3eb498e7b4 |
@@ -117,6 +117,8 @@
|
||||
title: Habana Gaudi
|
||||
- local: optimization/tome
|
||||
title: Token Merging
|
||||
- local: optimization/bentoml
|
||||
title: BentoML Integration
|
||||
title: Optimization/Special Hardware
|
||||
- sections:
|
||||
- local: conceptual/philosophy
|
||||
@@ -164,6 +166,8 @@
|
||||
title: VQModel
|
||||
- local: api/models/autoencoderkl
|
||||
title: AutoencoderKL
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/transformer2d
|
||||
title: Transformer2D
|
||||
- local: api/models/transformer_temporal
|
||||
|
||||
@@ -35,3 +35,11 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio
|
||||
## FromSingleFileMixin
|
||||
|
||||
[[autodoc]] loaders.FromSingleFileMixin
|
||||
|
||||
## FromOriginalControlnetMixin
|
||||
|
||||
[[autodoc]] loaders.FromOriginalControlnetMixin
|
||||
|
||||
## FromOriginalVAEMixin
|
||||
|
||||
[[autodoc]] loaders.FromOriginalVAEMixin
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
# AsymmetricAutoencoderKL
|
||||
|
||||
Improved larger variational autoencoder (VAE) model with KL loss for inpainting task: [Designing a Better Asymmetric VQGAN for StableDiffusion](https://arxiv.org/abs/2306.04632) by Zixin Zhu, Xuelu Feng, Dongdong Chen, Jianmin Bao, Le Wang, Yinpeng Chen, Lu Yuan, Gang Hua.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*StableDiffusion is a revolutionary text-to-image generator that is causing a stir in the world of image generation and editing. Unlike traditional methods that learn a diffusion model in pixel space, StableDiffusion learns a diffusion model in the latent space via a VQGAN, ensuring both efficiency and quality. It not only supports image generation tasks, but also enables image editing for real images, such as image inpainting and local editing. However, we have observed that the vanilla VQGAN used in StableDiffusion leads to significant information loss, causing distortion artifacts even in non-edited image regions. To this end, we propose a new asymmetric VQGAN with two simple designs. Firstly, in addition to the input from the encoder, the decoder contains a conditional branch that incorporates information from task-specific priors, such as the unmasked image region in inpainting. Secondly, the decoder is much heavier than the encoder, allowing for more detailed recovery while only slightly increasing the total inference cost. The training cost of our asymmetric VQGAN is cheap, and we only need to retrain a new asymmetric decoder while keeping the vanilla VQGAN encoder and StableDiffusion unchanged. Our asymmetric VQGAN can be widely used in StableDiffusion-based inpainting and local editing methods. Extensive experiments demonstrate that it can significantly improve the inpainting and editing performance, while maintaining the original text-to-image capability. The code is available at https://github.com/buxiangzhiren/Asymmetric_VQGAN*
|
||||
|
||||
Evaluation results can be found in section 4.1 of the original paper.
|
||||
|
||||
## Available checkpoints
|
||||
|
||||
* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5)
|
||||
* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2)
|
||||
|
||||
## Example Usage
|
||||
|
||||
```python
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import requests
|
||||
from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline
|
||||
|
||||
|
||||
def download_image(url: str) -> Image.Image:
|
||||
response = requests.get(url)
|
||||
return Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
|
||||
prompt = "a photo of a person"
|
||||
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
|
||||
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
|
||||
|
||||
image = download_image(img_url).resize((256, 256))
|
||||
mask_image = download_image(mask_url).resize((256, 256))
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
|
||||
pipe.vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
|
||||
pipe.to("cuda")
|
||||
|
||||
image = pipe(prompt=prompt, image=image, mask_image=mask_image).images[0]
|
||||
image.save("image.jpeg")
|
||||
```
|
||||
|
||||
## AsymmetricAutoencoderKL
|
||||
|
||||
[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.vae.DecoderOutput
|
||||
@@ -6,6 +6,18 @@ The abstract from the paper is:
|
||||
|
||||
*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.*
|
||||
|
||||
## Loading from the original format
|
||||
|
||||
By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
|
||||
from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
|
||||
## AutoencoderKL
|
||||
|
||||
[[autodoc]] AutoencoderKL
|
||||
@@ -28,4 +40,4 @@ The abstract from the paper is:
|
||||
|
||||
## FlaxDecoderOutput
|
||||
|
||||
[[autodoc]] models.vae_flax.FlaxDecoderOutput
|
||||
[[autodoc]] models.vae_flax.FlaxDecoderOutput
|
||||
|
||||
@@ -6,6 +6,21 @@ The abstract from the paper is:
|
||||
|
||||
*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*
|
||||
|
||||
## Loading from the original format
|
||||
|
||||
By default the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
|
||||
from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlnetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
controlnet = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlnetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
|
||||
## ControlNetModel
|
||||
|
||||
[[autodoc]] ControlNetModel
|
||||
@@ -20,4 +35,4 @@ The abstract from the paper is:
|
||||
|
||||
## FlaxControlNetOutput
|
||||
|
||||
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
|
||||
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
|
||||
|
||||
@@ -128,6 +128,63 @@ gif_path = export_to_gif(images[0], "burger_3d.gif")
|
||||
```
|
||||

|
||||
|
||||
### Generate mesh
|
||||
|
||||
For both [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`], you can generate mesh output by passing `output_type` as `mesh` to the pipeline, and then use the [`ShapEPipeline.export_to_ply`] utility function to save the output as a `ply` file. We also provide a [`ShapEPipeline.export_to_obj`] function that you can use to save mesh outputs as `obj` files.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import export_to_ply
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
repo = "openai/shap-e"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16, variant="fp16")
|
||||
pipe = pipe.to(device)
|
||||
|
||||
guidance_scale = 15.0
|
||||
prompt = "A birthday cupcake"
|
||||
|
||||
images = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, frame_size=256, output_type="mesh").images
|
||||
|
||||
ply_path = export_to_ply(images[0], "3d_cake.ply")
|
||||
print(f"saved to folder: {ply_path}")
|
||||
```
|
||||
|
||||
Huggingface Datasets supports mesh visualization for mesh files in `glb` format. Below we will show you how to convert your mesh file into `glb` format so that you can use the Dataset viewer to render 3D objects.
|
||||
|
||||
We need to install `trimesh` library.
|
||||
|
||||
```
|
||||
pip install trimesh
|
||||
```
|
||||
|
||||
To convert the mesh file into `glb` format,
|
||||
|
||||
```python
|
||||
import trimesh
|
||||
|
||||
mesh = trimesh.load("3d_cake.ply")
|
||||
mesh.export("3d_cake.glb", file_type="glb")
|
||||
```
|
||||
|
||||
By default, the mesh output of Shap-E is from the bottom viewpoint; you can change the default viewpoint by applying a rotation transformation
|
||||
|
||||
```python
|
||||
import trimesh
|
||||
import numpy as np
|
||||
|
||||
mesh = trimesh.load("3d_cake.ply")
|
||||
rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
|
||||
mesh = mesh.apply_transform(rot)
|
||||
mesh.export("3d_cake.glb", file_type="glb")
|
||||
```
|
||||
|
||||
Now you can upload your mesh file to your dataset and visualize it! Here is the link to the 3D cake we just generated
|
||||
https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/shap_e/3d_cake.glb
|
||||
|
||||
## ShapEPipeline
|
||||
[[autodoc]] ShapEPipeline
|
||||
- all
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# BentoML Integration Guide
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
[BentoML](https://github.com/bentoml/BentoML/) is an open-source framework designed for building,
|
||||
shipping, and scaling AI applications. It allows users to easily package and serve diffusion models
|
||||
for production, ensuring reliable and efficient deployments. It features out-of-the-box operational
|
||||
management tools like monitoring and tracing, and facilitates the deployment to various cloud platforms
|
||||
with ease. BentoML's distributed architecture and the separation of API server logic from
|
||||
model inference logic enable efficient scaling of deployments, even with budget constraints.
|
||||
As a result, integrating it with Diffusers provides a valuable tool for real-world deployments.
|
||||
|
||||
This tutorial demonstrates how to integrate BentoML with Diffusers.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Install [Diffusers](https://huggingface.co/docs/diffusers/installation).
|
||||
- Install BentoML by running `pip install bentoml`. For more information, see the [BentoML documentation](https://docs.bentoml.com).
|
||||
|
||||
## Import a diffusion model
|
||||
|
||||
First, you need to prepare the model. BentoML has its own [Model Store](https://docs.bentoml.com/en/latest/concepts/model.html)
|
||||
for model management. Create a `download_model.py` file as below to import a diffusion model into BentoML's Model
|
||||
Store:
|
||||
|
||||
```py
|
||||
import bentoml
|
||||
|
||||
bentoml.diffusers.import_model(
|
||||
"sd2.1", # Model tag in the BentoML Model Store
|
||||
"stabilityai/stable-diffusion-2-1", # Hugging Face model identifier
|
||||
)
|
||||
```
|
||||
|
||||
This code snippet downloads the Stable Diffusion 2.1 model (using it's repo id
|
||||
`stabilityai/stable-diffusion-2-1`) from the Hugging Face Hub (or use the cached download
|
||||
files if the model is already downloaded) and imports it into the BentoML Model
|
||||
Store with the name `sd2.1`.
|
||||
|
||||
For models already fine-tuned and stored on disk, you can provide the path instead of
|
||||
the repo id.
|
||||
|
||||
```py
|
||||
import bentoml
|
||||
|
||||
bentoml.diffusers.import_model(
|
||||
"sd2.1-local",
|
||||
"./local_stable_diffusion_2.1/",
|
||||
)
|
||||
```
|
||||
|
||||
You can view the model in the Model Store:
|
||||
|
||||
```
|
||||
bentoml models list
|
||||
|
||||
Tag Module Size Creation Time
|
||||
sd2.1:ysrlmubascajwnry bentoml.diffusers 33.85 GiB 2023-07-12 16:47:44
|
||||
```
|
||||
|
||||
## Turn a diffusion model into a RESTful service with BentoML
|
||||
|
||||
Once the diffusion model is in BentoML's Model Store, you can implement a text-to-image
|
||||
service with it. The Stable Diffusion model accepts various arguments
|
||||
in addition to the required prompt to guide the image generation process.
|
||||
To validate these input arguments, use BentoML's [pydantic](https://github.com/pydantic/pydantic) integration.
|
||||
Create a `sdargs.py` file with an example pydantic model:
|
||||
|
||||
```py
|
||||
import typing as t
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SDArgs(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: t.Optional[str] = None
|
||||
height: t.Optional[int] = 512
|
||||
width: t.Optional[int] = 512
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
```
|
||||
|
||||
This pydantic model requires a string field `prompt` and three optional fields: `height`, `width`, and `negative_prompt`,
|
||||
each with corresponding types. The `extra = "allow"` line supports adding additional fields not defined in the `SDArgs` class.
|
||||
In a real-world scenario, you may define all the desired fields and not allow extra ones.
|
||||
|
||||
Next, create a BentoML Service file that defines a Stable Diffusion service:
|
||||
|
||||
```py
|
||||
import bentoml
|
||||
from bentoml.io import Image, JSON
|
||||
|
||||
from sdargs import SDArgs
|
||||
|
||||
bento_model = bentoml.diffusers.get("sd2.1:latest")
|
||||
sd21_runner = bento_model.to_runner(name="sd21-runner")
|
||||
|
||||
svc = bentoml.Service("stable-diffusion-21", runners=[sd21_runner])
|
||||
|
||||
|
||||
@svc.api(input=JSON(pydantic_model=SDArgs), output=Image())
|
||||
async def txt2img(input_data):
|
||||
kwargs = input_data.dict()
|
||||
res = await sd21_runner.async_run(**kwargs)
|
||||
images = res[0]
|
||||
return images[0]
|
||||
```
|
||||
|
||||
Save the file as `service.py`, and spin up a BentoML Service endpoint using:
|
||||
|
||||
```
|
||||
bentoml serve service:svc
|
||||
```
|
||||
|
||||
An HTTP server with `/txt2img` endpoint that accepts a JSON dictionary should be up at
|
||||
port 3000. Go to <http://127.0.0.1:3000> in your web browser to access the Swagger UI.
|
||||
|
||||
You can also test the text-to-image generation using `curl` and write the returned image to
|
||||
`output.jpg`.
|
||||
|
||||
```
|
||||
curl -X POST http://127.0.0.1:3000/txt2img \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d "{\"prompt\":\"a black cat\", \"height\":768, \"width\":768}" \
|
||||
--output output.jpg
|
||||
```
|
||||
|
||||
## Package a BentoML Service for cloud deployment
|
||||
|
||||
To deploy a BentoML Service, you need to pack it into a BentoML
|
||||
[Bento](https://docs.bentoml.com/en/latest/concepts/bento.html), a file archive with all the source code,
|
||||
models, data files, and dependencies. This can be done by providing a `bentofile.yaml` file as follows:
|
||||
|
||||
```yaml
|
||||
service: "service.py:svc"
|
||||
include:
|
||||
- "service.py"
|
||||
python:
|
||||
packages:
|
||||
- torch
|
||||
- transformers
|
||||
- accelerate
|
||||
- diffusers
|
||||
- triton
|
||||
- xformers
|
||||
- pydantic
|
||||
docker:
|
||||
distro: debian
|
||||
cuda_version: "11.6"
|
||||
```
|
||||
|
||||
The `bentofile.yaml` file contains [Bento build
|
||||
options](https://docs.bentoml.com/en/latest/concepts/bento.html#bento-build-options),
|
||||
such as package dependencies and Docker options.
|
||||
|
||||
Then you build a Bento using:
|
||||
|
||||
```
|
||||
bentoml build
|
||||
```
|
||||
|
||||
The output looks like:
|
||||
|
||||
```
|
||||
Successfully built Bento(tag="stable-diffusion-21:crkuh7a7rw5bcasc").
|
||||
|
||||
Possible next steps:
|
||||
|
||||
* Containerize your Bento with `bentoml containerize`:
|
||||
$ bentoml containerize stable-diffusion-21:crkuh7a7rw5bcasc
|
||||
|
||||
* Push to BentoCloud with `bentoml push`:
|
||||
$ bentoml push stable-diffusion-21:crkuh7a7rw5bcasc
|
||||
```
|
||||
|
||||
You can create a Docker image based on the Bento by running the following command and deploy it to a cloud provider.
|
||||
|
||||
```
|
||||
bentoml containerize stable-diffusion-21:crkuh7a7rw5bcasc
|
||||
```
|
||||
|
||||
If you want an end-to-end solution for deploying and managing models, you can push the Bento to [Yatai](https://github.com/bentoml/Yatai) or
|
||||
[BentoCloud](https://bentoml.com/cloud) for a distributed deployment.
|
||||
|
||||
For more information about BentoML's integration with Diffusers, see the [BentoML Diffusers
|
||||
Guide](https://docs.bentoml.com/en/latest/frameworks/diffusers.html).
|
||||
@@ -274,9 +274,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
|
||||
# speed up diffusion process with faster scheduler and memory optimization
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
# remove following line if xformers is not installed
|
||||
# remove following line if xformers is not installed or when using Torch 2.0.
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# memory optimization.
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image("./conditioning_image_1.png")
|
||||
@@ -285,9 +285,8 @@ prompt = "pale golden rod circle with old lace background"
|
||||
# generate image
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=20, generator=generator, image=control_image
|
||||
prompt, num_inference_steps=20, generator=generator, image=control_image
|
||||
).images[0]
|
||||
|
||||
image.save("./output.png")
|
||||
```
|
||||
|
||||
@@ -460,3 +459,7 @@ The profile can then be inspected at http://localhost:6006/#profile
|
||||
Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
|
||||
|
||||
Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).
|
||||
|
||||
## Support for Stable Diffusion XL
|
||||
|
||||
We provide a training script for training a ControlNet with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to [README_sdxl.md](./README_sdxl.md) for more details.
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
# DreamBooth training example for Stable Diffusion XL (SDXL)
|
||||
|
||||
The `train_controlnet_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/controlnet` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sdxl.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
## Circle filling dataset
|
||||
|
||||
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
|
||||
|
||||
## Training
|
||||
|
||||
Our training examples use two test conditioning images. They can be downloaded by running
|
||||
|
||||
```sh
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
|
||||
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
|
||||
```
|
||||
|
||||
Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
|
||||
export OUTPUT_DIR="path to save model"
|
||||
|
||||
accelerate launch train_controlnet_sdxl.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--mixed_precision="fp16" \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-5 \
|
||||
--max_train_steps=15000 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--validation_steps=100 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--report_to="wandb" \
|
||||
--seed=42 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
base_model_path = "stabilityai/stable-diffusion-xl-base-0.9"
|
||||
controlnet_path = "path to controlnet"
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
base_model_path, controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
# speed up diffusion process with faster scheduler and memory optimization
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
# remove following line if xformers is not installed or when using Torch 2.0.
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
# memory optimization.
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image("./conditioning_image_1.png")
|
||||
prompt = "pale golden rod circle with old lace background"
|
||||
|
||||
# generate image
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=20, generator=generator, image=control_image
|
||||
).images[0]
|
||||
image.save("./output.png")
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
### Specifying a better VAE
|
||||
|
||||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
@@ -0,0 +1,235 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This file is heavily inspired by https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import webdataset as wds
|
||||
from braceexpand import braceexpand
|
||||
from torch.utils.data import default_collate
|
||||
from torchvision import transforms
|
||||
from transformers import PreTrainedTokenizer
|
||||
from webdataset.tariterators import (
|
||||
base_plus_ext,
|
||||
tar_file_expander,
|
||||
url_opener,
|
||||
valid_sample,
|
||||
)
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def filter_keys(key_set):
|
||||
def _f(dictionary):
|
||||
return {k: v for k, v in dictionary.items() if k in key_set}
|
||||
|
||||
return _f
|
||||
|
||||
|
||||
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
|
||||
"""Return function over iterator that groups key, value pairs into samples.
|
||||
|
||||
:param keys: function that splits the key into key and extension (base_plus_ext)
|
||||
:param lcase: convert suffixes to lower case (Default value = True)
|
||||
"""
|
||||
current_sample = None
|
||||
for filesample in data:
|
||||
assert isinstance(filesample, dict)
|
||||
fname, value = filesample["fname"], filesample["data"]
|
||||
prefix, suffix = keys(fname)
|
||||
if prefix is None:
|
||||
continue
|
||||
if lcase:
|
||||
suffix = suffix.lower()
|
||||
# FIXME webdataset version throws if suffix in current_sample, but we have a potential for
|
||||
# this happening in the current LAION400m dataset if a tar ends with same prefix as the next
|
||||
# begins, rare, but can happen since prefix aren't unique across tar files in that dataset
|
||||
if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
|
||||
if valid_sample(current_sample):
|
||||
yield current_sample
|
||||
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
|
||||
if suffixes is None or suffix in suffixes:
|
||||
current_sample[suffix] = value
|
||||
if valid_sample(current_sample):
|
||||
yield current_sample
|
||||
|
||||
|
||||
def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
|
||||
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
|
||||
streams = url_opener(src, handler=handler)
|
||||
files = tar_file_expander(streams, handler=handler)
|
||||
samples = group_by_keys_nothrow(files, handler=handler)
|
||||
return samples
|
||||
|
||||
|
||||
def control_transform(image):
|
||||
image = np.array(image)
|
||||
|
||||
low_threshold = 100
|
||||
high_threshold = 200
|
||||
|
||||
image = cv2.Canny(image, low_threshold, high_threshold)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
control_image = Image.fromarray(image)
|
||||
return control_image
|
||||
|
||||
|
||||
class ImageNetTransform:
|
||||
def __init__(self, resolution, center_crop=True, random_flip=False):
|
||||
self.train_transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
self.train_control_transform = transforms.Compose(
|
||||
[
|
||||
control_transform,
|
||||
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(resolution),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
self.eval_transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(resolution),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Text2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
eval_shards_path_or_url: Union[str, List[str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_seq_length: int,
|
||||
num_train_examples: int,
|
||||
per_gpu_batch_size: int,
|
||||
global_batch_size: int,
|
||||
num_workers: int,
|
||||
tokenizer_two: Optional[PreTrainedTokenizer] = None,
|
||||
resolution: int = 256,
|
||||
center_crop: bool = True,
|
||||
random_flip: bool = False,
|
||||
shuffle_buffer_size: int = 1000,
|
||||
pin_memory: bool = False,
|
||||
persistent_workers: bool = False,
|
||||
):
|
||||
transform = ImageNetTransform(resolution, center_crop, random_flip)
|
||||
|
||||
def tokenize(text):
|
||||
input_ids = tokenizer(
|
||||
text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
).input_ids
|
||||
return input_ids[0]
|
||||
|
||||
def tokenize_2(text):
|
||||
input_ids = tokenizer_two(
|
||||
text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
).input_ids
|
||||
return input_ids[0]
|
||||
|
||||
if not isinstance(train_shards_path_or_url, str):
|
||||
train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
|
||||
# flatten list using itertools
|
||||
train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
|
||||
|
||||
if not isinstance(eval_shards_path_or_url, str):
|
||||
eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]
|
||||
# flatten list using itertools
|
||||
eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))
|
||||
|
||||
processing_pipeline = [
|
||||
wds.decode("pil", handler=wds.ignore_and_continue),
|
||||
wds.rename(image="jpg;png;jpeg;webp", control_image="jpg;png;jpeg;webp", input_ids="text;txt;caption", input_ids_2="text;txt;caption", handler=wds.warn_and_continue),
|
||||
wds.map(filter_keys(set(["image", "control_image", "input_ids", "input_ids_2"]))),
|
||||
wds.map_dict(image=transform.train_transform, control_image=transform.train_control_transform, input_ids=tokenize, input_ids_2=tokenize_2),
|
||||
wds.to_tuple("image", "control_image", "input_ids", "input_ids_2"),
|
||||
]
|
||||
|
||||
# Create train dataset and loader
|
||||
pipeline = [
|
||||
wds.ResampledShards(train_shards_path_or_url),
|
||||
tarfile_to_samples_nothrow,
|
||||
wds.shuffle(shuffle_buffer_size),
|
||||
*processing_pipeline,
|
||||
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
|
||||
]
|
||||
|
||||
num_batches = math.ceil(num_train_examples / global_batch_size)
|
||||
num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
|
||||
num_batches = num_worker_batches * num_workers
|
||||
num_samples = num_batches * global_batch_size
|
||||
|
||||
# each worker is iterating over this
|
||||
self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
|
||||
self._train_dataloader = wds.WebLoader(
|
||||
self._train_dataset,
|
||||
batch_size=None,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
persistent_workers=persistent_workers,
|
||||
)
|
||||
# add meta-data to dataloader instance for convenience
|
||||
self._train_dataloader.num_batches = num_batches
|
||||
self._train_dataloader.num_samples = num_samples
|
||||
|
||||
# Create eval dataset and loader
|
||||
pipeline = [
|
||||
wds.SimpleShardList(eval_shards_path_or_url),
|
||||
wds.split_by_worker,
|
||||
wds.tarfile_to_samples(handler=wds.ignore_and_continue),
|
||||
*processing_pipeline,
|
||||
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
|
||||
]
|
||||
self._eval_dataset = wds.DataPipeline(*pipeline)
|
||||
self._eval_dataloader = wds.WebLoader(
|
||||
self._eval_dataset,
|
||||
batch_size=None,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
persistent_workers=persistent_workers,
|
||||
)
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self._train_dataset
|
||||
|
||||
@property
|
||||
def train_dataloader(self):
|
||||
return self._train_dataloader
|
||||
|
||||
@property
|
||||
def eval_dataset(self):
|
||||
return self._eval_dataset
|
||||
|
||||
@property
|
||||
def eval_dataloader(self):
|
||||
return self._eval_dataloader
|
||||
@@ -0,0 +1,9 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
invisible-watermark>=0.2.0
|
||||
datasets
|
||||
wandb
|
||||
@@ -0,0 +1,28 @@
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
|
||||
export OUTPUT_DIR="controlnet-0-9-canny"
|
||||
|
||||
# --max_train_steps=15000 \
|
||||
accelerate launch train_controlnet_webdatasets.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="fp16" \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-5 \
|
||||
--max_train_steps=30000 \
|
||||
--max_train_samples=12000000 \
|
||||
--dataloader_num_workers=4 \
|
||||
--validation_image "./c_image_0.png" "./c_image_1.png" "./c_image_2.png" "./c_image_3.png" "./c_image_4.png" "./c_image_5.png" "./c_image_6.png" "./c_image_7.png" \
|
||||
--validation_prompt "beautiful room" "two paradise birds" "a snowy house behind a forest" "a couple watching a romantic sunset" "boats in the Amazonas" "a beautiful face of a woman" "a skater in Brooklyn" "a tornado in Iowa" \
|
||||
--train_shards_path_or_url "pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/{00000..01208}.tar -" \
|
||||
--eval_shards_path_or_url "pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/{01209..01210}.tar -" \
|
||||
--proportion_empty_prompts 0.5 \
|
||||
--validation_steps=1000 \
|
||||
--train_batch_size=12 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--enable_xformers_memory_efficient_attention \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--seed=42 \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,184 @@
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import AsymmetricAutoencoderKL
|
||||
|
||||
|
||||
ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": [
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
],
|
||||
"down_block_out_channels": [128, 256, 512, 512],
|
||||
"layers_per_down_block": 2,
|
||||
"up_block_types": [
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
],
|
||||
"up_block_out_channels": [192, 384, 768, 768],
|
||||
"layers_per_up_block": 3,
|
||||
"act_fn": "silu",
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 256,
|
||||
"scaling_factor": 0.18215,
|
||||
}
|
||||
|
||||
ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": [
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
],
|
||||
"down_block_out_channels": [128, 256, 512, 512],
|
||||
"layers_per_down_block": 2,
|
||||
"up_block_types": [
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
],
|
||||
"up_block_out_channels": [256, 512, 1024, 1024],
|
||||
"layers_per_up_block": 5,
|
||||
"act_fn": "silu",
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 256,
|
||||
"scaling_factor": 0.18215,
|
||||
}
|
||||
|
||||
|
||||
def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if k.startswith("encoder."):
|
||||
converted_state_dict[
|
||||
k.replace("encoder.down.", "encoder.down_blocks.")
|
||||
.replace("encoder.mid.", "encoder.mid_block.")
|
||||
.replace("encoder.norm_out.", "encoder.conv_norm_out.")
|
||||
.replace(".downsample.", ".downsamplers.0.")
|
||||
.replace(".nin_shortcut.", ".conv_shortcut.")
|
||||
.replace(".block.", ".resnets.")
|
||||
.replace(".block_1.", ".resnets.0.")
|
||||
.replace(".block_2.", ".resnets.1.")
|
||||
.replace(".attn_1.k.", ".attentions.0.to_k.")
|
||||
.replace(".attn_1.q.", ".attentions.0.to_q.")
|
||||
.replace(".attn_1.v.", ".attentions.0.to_v.")
|
||||
.replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
|
||||
.replace(".attn_1.norm.", ".attentions.0.group_norm.")
|
||||
] = v
|
||||
elif k.startswith("decoder.") and "up_layers" not in k:
|
||||
converted_state_dict[
|
||||
k.replace("decoder.encoder.", "decoder.condition_encoder.")
|
||||
.replace(".norm_out.", ".conv_norm_out.")
|
||||
.replace(".up.0.", ".up_blocks.3.")
|
||||
.replace(".up.1.", ".up_blocks.2.")
|
||||
.replace(".up.2.", ".up_blocks.1.")
|
||||
.replace(".up.3.", ".up_blocks.0.")
|
||||
.replace(".block.", ".resnets.")
|
||||
.replace("mid", "mid_block")
|
||||
.replace(".0.upsample.", ".0.upsamplers.0.")
|
||||
.replace(".1.upsample.", ".1.upsamplers.0.")
|
||||
.replace(".2.upsample.", ".2.upsamplers.0.")
|
||||
.replace(".nin_shortcut.", ".conv_shortcut.")
|
||||
.replace(".block_1.", ".resnets.0.")
|
||||
.replace(".block_2.", ".resnets.1.")
|
||||
.replace(".attn_1.k.", ".attentions.0.to_k.")
|
||||
.replace(".attn_1.q.", ".attentions.0.to_q.")
|
||||
.replace(".attn_1.v.", ".attentions.0.to_v.")
|
||||
.replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
|
||||
.replace(".attn_1.norm.", ".attentions.0.group_norm.")
|
||||
] = v
|
||||
elif k.startswith("quant_conv."):
|
||||
converted_state_dict[k] = v
|
||||
elif k.startswith("post_quant_conv."):
|
||||
converted_state_dict[k] = v
|
||||
else:
|
||||
print(f" skipping key `{k}`")
|
||||
# fix weights shape
|
||||
for k, v in converted_state_dict.items():
|
||||
if (
|
||||
(k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
|
||||
and k.endswith("weight")
|
||||
and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
|
||||
):
|
||||
converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_asymmetric_autoencoder_kl_from_original_checkpoint(
|
||||
scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
|
||||
) -> AsymmetricAutoencoderKL:
|
||||
print("Loading original state_dict")
|
||||
original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
|
||||
original_state_dict = original_state_dict["state_dict"]
|
||||
print("Converting state_dict")
|
||||
converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
|
||||
kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
|
||||
print("Initializing AsymmetricAutoencoderKL model")
|
||||
asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
|
||||
print("Loading weight from converted state_dict")
|
||||
asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
|
||||
asymmetric_autoencoder_kl.eval()
|
||||
print("AsymmetricAutoencoderKL successfully initialized")
|
||||
return asymmetric_autoencoder_kl
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start = time.time()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Asymmetric VQGAN scale: `1.5` or `2`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--original_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the original Asymmetric VQGAN checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to save pretrained AsymmetricAutoencoderKL model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--map_location",
|
||||
default="cpu",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The device passed to `map_location` when loading the checkpoint",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` of `2`"
|
||||
assert Path(args.original_checkpoint_path).is_file()
|
||||
|
||||
asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
|
||||
scale=args.scale,
|
||||
original_checkpoint_path=args.original_checkpoint_path,
|
||||
map_location=torch.device(args.map_location),
|
||||
)
|
||||
print("Saving pretrained AsymmetricAutoencoderKL")
|
||||
asymmetric_autoencoder_kl.save_pretrained(args.output_path)
|
||||
print(f"Done in {time.time() - start:.2f} seconds")
|
||||
@@ -22,7 +22,7 @@ $ python scripts/convert_shap_e_to_diffusers.py \
|
||||
--prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
|
||||
--prior_image_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/image_cond.pt \
|
||||
--transmitter_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt\
|
||||
--dump_path /home/yiyi_huggingface_co/model_repo/shap-e/renderer\
|
||||
--dump_path /home/yiyi_huggingface_co/model_repo/shap-e-img2img/shap_e_renderer\
|
||||
--debug renderer
|
||||
```
|
||||
"""
|
||||
@@ -373,6 +373,487 @@ def prior_image_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
|
||||
# renderer
|
||||
|
||||
## create the lookup table for marching cubes method used in MeshDecoder
|
||||
|
||||
MC_TABLE = [
|
||||
[],
|
||||
[[0, 1, 0, 2, 0, 4]],
|
||||
[[1, 0, 1, 5, 1, 3]],
|
||||
[[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2]],
|
||||
[[2, 0, 2, 3, 2, 6]],
|
||||
[[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4]],
|
||||
[[1, 0, 1, 5, 1, 3], [2, 6, 0, 2, 3, 2]],
|
||||
[[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4]],
|
||||
[[3, 1, 3, 7, 3, 2]],
|
||||
[[0, 2, 0, 4, 0, 1], [3, 7, 2, 3, 1, 3]],
|
||||
[[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0]],
|
||||
[[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5]],
|
||||
[[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6]],
|
||||
[[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6]],
|
||||
[[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7]],
|
||||
[[0, 4, 1, 5, 3, 7], [0, 4, 3, 7, 2, 6]],
|
||||
[[4, 0, 4, 6, 4, 5]],
|
||||
[[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1]],
|
||||
[[1, 5, 1, 3, 1, 0], [4, 6, 5, 4, 0, 4]],
|
||||
[[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2]],
|
||||
[[2, 0, 2, 3, 2, 6], [4, 5, 0, 4, 6, 4]],
|
||||
[[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1]],
|
||||
[[2, 6, 2, 0, 3, 2], [1, 0, 1, 5, 3, 1], [6, 4, 5, 4, 0, 4]],
|
||||
[[1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6]],
|
||||
[[3, 1, 3, 7, 3, 2], [6, 4, 5, 4, 0, 4]],
|
||||
[[4, 5, 0, 1, 4, 6], [0, 1, 0, 2, 4, 6], [7, 3, 2, 3, 1, 3]],
|
||||
[[3, 2, 1, 0, 3, 7], [1, 0, 1, 5, 3, 7], [6, 4, 5, 4, 0, 4]],
|
||||
[[3, 7, 3, 2, 1, 5], [3, 2, 6, 4, 1, 5], [1, 5, 6, 4, 5, 4], [3, 2, 2, 0, 6, 4]],
|
||||
[[3, 7, 2, 6, 3, 1], [2, 6, 2, 0, 3, 1], [5, 4, 0, 4, 6, 4]],
|
||||
[[1, 0, 1, 3, 5, 4], [1, 3, 2, 6, 5, 4], [1, 3, 3, 7, 2, 6], [5, 4, 2, 6, 4, 6]],
|
||||
[[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7], [4, 5, 0, 4, 4, 6]],
|
||||
[[6, 2, 4, 6, 4, 5], [4, 5, 5, 1, 6, 2], [6, 2, 5, 1, 7, 3]],
|
||||
[[5, 1, 5, 4, 5, 7]],
|
||||
[[0, 1, 0, 2, 0, 4], [5, 7, 1, 5, 4, 5]],
|
||||
[[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3]],
|
||||
[[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3]],
|
||||
[[2, 0, 2, 3, 2, 6], [7, 5, 1, 5, 4, 5]],
|
||||
[[2, 6, 0, 4, 2, 3], [0, 4, 0, 1, 2, 3], [7, 5, 1, 5, 4, 5]],
|
||||
[[5, 7, 1, 3, 5, 4], [1, 3, 1, 0, 5, 4], [6, 2, 0, 2, 3, 2]],
|
||||
[[3, 1, 3, 2, 7, 5], [3, 2, 0, 4, 7, 5], [3, 2, 2, 6, 0, 4], [7, 5, 0, 4, 5, 4]],
|
||||
[[3, 7, 3, 2, 3, 1], [5, 4, 7, 5, 1, 5]],
|
||||
[[0, 4, 0, 1, 2, 0], [3, 1, 3, 7, 2, 3], [4, 5, 7, 5, 1, 5]],
|
||||
[[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0]],
|
||||
[[0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7]],
|
||||
[[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6], [4, 5, 7, 5, 1, 5]],
|
||||
[[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6], [5, 7, 1, 5, 5, 4]],
|
||||
[[2, 6, 2, 0, 3, 7], [2, 0, 4, 5, 3, 7], [3, 7, 4, 5, 7, 5], [2, 0, 0, 1, 4, 5]],
|
||||
[[4, 0, 5, 4, 5, 7], [5, 7, 7, 3, 4, 0], [4, 0, 7, 3, 6, 2]],
|
||||
[[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0]],
|
||||
[[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6]],
|
||||
[[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7]],
|
||||
[[0, 2, 4, 6, 5, 7], [0, 2, 5, 7, 1, 3]],
|
||||
[[5, 1, 4, 0, 5, 7], [4, 0, 4, 6, 5, 7], [3, 2, 6, 2, 0, 2]],
|
||||
[[2, 3, 2, 6, 0, 1], [2, 6, 7, 5, 0, 1], [0, 1, 7, 5, 1, 5], [2, 6, 6, 4, 7, 5]],
|
||||
[[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7], [2, 6, 0, 2, 2, 3]],
|
||||
[[3, 1, 2, 3, 2, 6], [2, 6, 6, 4, 3, 1], [3, 1, 6, 4, 7, 5]],
|
||||
[[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0], [2, 3, 1, 3, 7, 3]],
|
||||
[[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6], [3, 2, 1, 3, 3, 7]],
|
||||
[[0, 1, 0, 4, 2, 3], [0, 4, 5, 7, 2, 3], [0, 4, 4, 6, 5, 7], [2, 3, 5, 7, 3, 7]],
|
||||
[[7, 5, 3, 7, 3, 2], [3, 2, 2, 0, 7, 5], [7, 5, 2, 0, 6, 4]],
|
||||
[[0, 4, 4, 6, 5, 7], [0, 4, 5, 7, 1, 5], [0, 2, 1, 3, 3, 7], [3, 7, 2, 6, 0, 2]],
|
||||
[
|
||||
[3, 1, 7, 3, 6, 2],
|
||||
[6, 2, 0, 1, 3, 1],
|
||||
[6, 4, 0, 1, 6, 2],
|
||||
[6, 4, 5, 1, 0, 1],
|
||||
[6, 4, 7, 5, 5, 1],
|
||||
],
|
||||
[
|
||||
[4, 0, 6, 4, 7, 5],
|
||||
[7, 5, 1, 0, 4, 0],
|
||||
[7, 3, 1, 0, 7, 5],
|
||||
[7, 3, 2, 0, 1, 0],
|
||||
[7, 3, 6, 2, 2, 0],
|
||||
],
|
||||
[[7, 3, 6, 2, 6, 4], [7, 5, 7, 3, 6, 4]],
|
||||
[[6, 2, 6, 7, 6, 4]],
|
||||
[[0, 4, 0, 1, 0, 2], [6, 7, 4, 6, 2, 6]],
|
||||
[[1, 0, 1, 5, 1, 3], [7, 6, 4, 6, 2, 6]],
|
||||
[[1, 3, 0, 2, 1, 5], [0, 2, 0, 4, 1, 5], [7, 6, 4, 6, 2, 6]],
|
||||
[[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0]],
|
||||
[[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3]],
|
||||
[[6, 4, 2, 0, 6, 7], [2, 0, 2, 3, 6, 7], [5, 1, 3, 1, 0, 1]],
|
||||
[[1, 5, 1, 3, 0, 4], [1, 3, 7, 6, 0, 4], [0, 4, 7, 6, 4, 6], [1, 3, 3, 2, 7, 6]],
|
||||
[[3, 2, 3, 1, 3, 7], [6, 4, 2, 6, 7, 6]],
|
||||
[[3, 7, 3, 2, 1, 3], [0, 2, 0, 4, 1, 0], [7, 6, 4, 6, 2, 6]],
|
||||
[[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0], [4, 6, 2, 6, 7, 6]],
|
||||
[[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5], [6, 4, 2, 6, 6, 7]],
|
||||
[[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0]],
|
||||
[[0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7]],
|
||||
[[0, 2, 0, 1, 4, 6], [0, 1, 3, 7, 4, 6], [0, 1, 1, 5, 3, 7], [4, 6, 3, 7, 6, 7]],
|
||||
[[7, 3, 6, 7, 6, 4], [6, 4, 4, 0, 7, 3], [7, 3, 4, 0, 5, 1]],
|
||||
[[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5]],
|
||||
[[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5]],
|
||||
[[6, 7, 4, 5, 6, 2], [4, 5, 4, 0, 6, 2], [3, 1, 0, 1, 5, 1]],
|
||||
[[2, 0, 2, 6, 3, 1], [2, 6, 4, 5, 3, 1], [2, 6, 6, 7, 4, 5], [3, 1, 4, 5, 1, 5]],
|
||||
[[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7]],
|
||||
[[0, 1, 2, 3, 6, 7], [0, 1, 6, 7, 4, 5]],
|
||||
[[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7], [1, 3, 0, 1, 1, 5]],
|
||||
[[5, 4, 1, 5, 1, 3], [1, 3, 3, 2, 5, 4], [5, 4, 3, 2, 7, 6]],
|
||||
[[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5], [1, 3, 7, 3, 2, 3]],
|
||||
[[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5], [3, 7, 2, 3, 3, 1]],
|
||||
[[0, 1, 1, 5, 3, 7], [0, 1, 3, 7, 2, 3], [0, 4, 2, 6, 6, 7], [6, 7, 4, 5, 0, 4]],
|
||||
[
|
||||
[6, 2, 7, 6, 5, 4],
|
||||
[5, 4, 0, 2, 6, 2],
|
||||
[5, 1, 0, 2, 5, 4],
|
||||
[5, 1, 3, 2, 0, 2],
|
||||
[5, 1, 7, 3, 3, 2],
|
||||
],
|
||||
[[3, 1, 3, 7, 2, 0], [3, 7, 5, 4, 2, 0], [2, 0, 5, 4, 0, 4], [3, 7, 7, 6, 5, 4]],
|
||||
[[1, 0, 3, 1, 3, 7], [3, 7, 7, 6, 1, 0], [1, 0, 7, 6, 5, 4]],
|
||||
[
|
||||
[1, 0, 5, 1, 7, 3],
|
||||
[7, 3, 2, 0, 1, 0],
|
||||
[7, 6, 2, 0, 7, 3],
|
||||
[7, 6, 4, 0, 2, 0],
|
||||
[7, 6, 5, 4, 4, 0],
|
||||
],
|
||||
[[7, 6, 5, 4, 5, 1], [7, 3, 7, 6, 5, 1]],
|
||||
[[5, 7, 5, 1, 5, 4], [6, 2, 7, 6, 4, 6]],
|
||||
[[0, 2, 0, 4, 1, 0], [5, 4, 5, 7, 1, 5], [2, 6, 7, 6, 4, 6]],
|
||||
[[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3], [2, 6, 7, 6, 4, 6]],
|
||||
[[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3], [6, 7, 4, 6, 6, 2]],
|
||||
[[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0], [1, 5, 4, 5, 7, 5]],
|
||||
[[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3], [5, 1, 4, 5, 5, 7]],
|
||||
[[0, 2, 2, 3, 6, 7], [0, 2, 6, 7, 4, 6], [0, 1, 4, 5, 5, 7], [5, 7, 1, 3, 0, 1]],
|
||||
[
|
||||
[5, 4, 7, 5, 3, 1],
|
||||
[3, 1, 0, 4, 5, 4],
|
||||
[3, 2, 0, 4, 3, 1],
|
||||
[3, 2, 6, 4, 0, 4],
|
||||
[3, 2, 7, 6, 6, 4],
|
||||
],
|
||||
[[5, 4, 5, 7, 1, 5], [3, 7, 3, 2, 1, 3], [4, 6, 2, 6, 7, 6]],
|
||||
[[1, 0, 0, 2, 0, 4], [1, 5, 5, 4, 5, 7], [3, 2, 1, 3, 3, 7], [2, 6, 7, 6, 4, 6]],
|
||||
[[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0], [6, 2, 7, 6, 6, 4]],
|
||||
[
|
||||
[0, 4, 2, 3, 0, 2],
|
||||
[0, 4, 3, 7, 2, 3],
|
||||
[0, 4, 4, 5, 3, 7],
|
||||
[4, 5, 5, 7, 3, 7],
|
||||
[6, 7, 4, 6, 2, 6],
|
||||
],
|
||||
[[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0], [5, 4, 7, 5, 5, 1]],
|
||||
[
|
||||
[0, 1, 4, 6, 0, 4],
|
||||
[0, 1, 6, 7, 4, 6],
|
||||
[0, 1, 1, 3, 6, 7],
|
||||
[1, 3, 3, 7, 6, 7],
|
||||
[5, 7, 1, 5, 4, 5],
|
||||
],
|
||||
[
|
||||
[6, 7, 4, 6, 0, 2],
|
||||
[0, 2, 3, 7, 6, 7],
|
||||
[0, 1, 3, 7, 0, 2],
|
||||
[0, 1, 5, 7, 3, 7],
|
||||
[0, 1, 4, 5, 5, 7],
|
||||
],
|
||||
[[4, 0, 6, 7, 4, 6], [4, 0, 7, 3, 6, 7], [4, 0, 5, 7, 7, 3], [4, 5, 5, 7, 4, 0]],
|
||||
[[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0]],
|
||||
[[0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7]],
|
||||
[[1, 3, 1, 0, 5, 7], [1, 0, 2, 6, 5, 7], [5, 7, 2, 6, 7, 6], [1, 0, 0, 4, 2, 6]],
|
||||
[[2, 0, 6, 2, 6, 7], [6, 7, 7, 5, 2, 0], [2, 0, 7, 5, 3, 1]],
|
||||
[[0, 4, 0, 2, 1, 5], [0, 2, 6, 7, 1, 5], [0, 2, 2, 3, 6, 7], [1, 5, 6, 7, 5, 7]],
|
||||
[[7, 6, 5, 7, 5, 1], [5, 1, 1, 0, 7, 6], [7, 6, 1, 0, 3, 2]],
|
||||
[
|
||||
[2, 0, 3, 2, 7, 6],
|
||||
[7, 6, 4, 0, 2, 0],
|
||||
[7, 5, 4, 0, 7, 6],
|
||||
[7, 5, 1, 0, 4, 0],
|
||||
[7, 5, 3, 1, 1, 0],
|
||||
],
|
||||
[[7, 5, 3, 1, 3, 2], [7, 6, 7, 5, 3, 2]],
|
||||
[[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0], [3, 1, 7, 3, 3, 2]],
|
||||
[
|
||||
[0, 2, 1, 5, 0, 1],
|
||||
[0, 2, 5, 7, 1, 5],
|
||||
[0, 2, 2, 6, 5, 7],
|
||||
[2, 6, 6, 7, 5, 7],
|
||||
[3, 7, 2, 3, 1, 3],
|
||||
],
|
||||
[
|
||||
[3, 7, 2, 3, 0, 1],
|
||||
[0, 1, 5, 7, 3, 7],
|
||||
[0, 4, 5, 7, 0, 1],
|
||||
[0, 4, 6, 7, 5, 7],
|
||||
[0, 4, 2, 6, 6, 7],
|
||||
],
|
||||
[[2, 0, 3, 7, 2, 3], [2, 0, 7, 5, 3, 7], [2, 0, 6, 7, 7, 5], [2, 6, 6, 7, 2, 0]],
|
||||
[
|
||||
[5, 7, 1, 5, 0, 4],
|
||||
[0, 4, 6, 7, 5, 7],
|
||||
[0, 2, 6, 7, 0, 4],
|
||||
[0, 2, 3, 7, 6, 7],
|
||||
[0, 2, 1, 3, 3, 7],
|
||||
],
|
||||
[[1, 0, 5, 7, 1, 5], [1, 0, 7, 6, 5, 7], [1, 0, 3, 7, 7, 6], [1, 3, 3, 7, 1, 0]],
|
||||
[[0, 2, 0, 1, 0, 4], [3, 7, 6, 7, 5, 7]],
|
||||
[[7, 5, 7, 3, 7, 6]],
|
||||
[[7, 3, 7, 5, 7, 6]],
|
||||
[[0, 1, 0, 2, 0, 4], [6, 7, 3, 7, 5, 7]],
|
||||
[[1, 3, 1, 0, 1, 5], [7, 6, 3, 7, 5, 7]],
|
||||
[[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2], [6, 7, 3, 7, 5, 7]],
|
||||
[[2, 6, 2, 0, 2, 3], [7, 5, 6, 7, 3, 7]],
|
||||
[[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4], [5, 7, 6, 7, 3, 7]],
|
||||
[[1, 5, 1, 3, 0, 1], [2, 3, 2, 6, 0, 2], [5, 7, 6, 7, 3, 7]],
|
||||
[[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4], [7, 6, 3, 7, 7, 5]],
|
||||
[[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2]],
|
||||
[[7, 6, 3, 2, 7, 5], [3, 2, 3, 1, 7, 5], [4, 0, 1, 0, 2, 0]],
|
||||
[[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2]],
|
||||
[[2, 3, 2, 0, 6, 7], [2, 0, 1, 5, 6, 7], [2, 0, 0, 4, 1, 5], [6, 7, 1, 5, 7, 5]],
|
||||
[[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1]],
|
||||
[[0, 4, 0, 1, 2, 6], [0, 1, 5, 7, 2, 6], [2, 6, 5, 7, 6, 7], [0, 1, 1, 3, 5, 7]],
|
||||
[[1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6]],
|
||||
[[5, 1, 7, 5, 7, 6], [7, 6, 6, 2, 5, 1], [5, 1, 6, 2, 4, 0]],
|
||||
[[4, 5, 4, 0, 4, 6], [7, 3, 5, 7, 6, 7]],
|
||||
[[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1], [3, 7, 5, 7, 6, 7]],
|
||||
[[4, 6, 4, 5, 0, 4], [1, 5, 1, 3, 0, 1], [6, 7, 3, 7, 5, 7]],
|
||||
[[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2], [7, 3, 5, 7, 7, 6]],
|
||||
[[2, 3, 2, 6, 0, 2], [4, 6, 4, 5, 0, 4], [3, 7, 5, 7, 6, 7]],
|
||||
[[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1], [7, 5, 6, 7, 7, 3]],
|
||||
[[0, 1, 1, 5, 1, 3], [0, 2, 2, 3, 2, 6], [4, 5, 0, 4, 4, 6], [5, 7, 6, 7, 3, 7]],
|
||||
[
|
||||
[1, 3, 5, 4, 1, 5],
|
||||
[1, 3, 4, 6, 5, 4],
|
||||
[1, 3, 3, 2, 4, 6],
|
||||
[3, 2, 2, 6, 4, 6],
|
||||
[7, 6, 3, 7, 5, 7],
|
||||
],
|
||||
[[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2], [0, 4, 6, 4, 5, 4]],
|
||||
[[1, 0, 0, 2, 4, 6], [1, 0, 4, 6, 5, 4], [1, 3, 5, 7, 7, 6], [7, 6, 3, 2, 1, 3]],
|
||||
[[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2], [4, 6, 5, 4, 4, 0]],
|
||||
[
|
||||
[7, 5, 6, 7, 2, 3],
|
||||
[2, 3, 1, 5, 7, 5],
|
||||
[2, 0, 1, 5, 2, 3],
|
||||
[2, 0, 4, 5, 1, 5],
|
||||
[2, 0, 6, 4, 4, 5],
|
||||
],
|
||||
[[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1], [4, 0, 6, 4, 4, 5]],
|
||||
[
|
||||
[4, 6, 5, 4, 1, 0],
|
||||
[1, 0, 2, 6, 4, 6],
|
||||
[1, 3, 2, 6, 1, 0],
|
||||
[1, 3, 7, 6, 2, 6],
|
||||
[1, 3, 5, 7, 7, 6],
|
||||
],
|
||||
[
|
||||
[1, 5, 0, 2, 1, 0],
|
||||
[1, 5, 2, 6, 0, 2],
|
||||
[1, 5, 5, 7, 2, 6],
|
||||
[5, 7, 7, 6, 2, 6],
|
||||
[4, 6, 5, 4, 0, 4],
|
||||
],
|
||||
[[5, 1, 4, 6, 5, 4], [5, 1, 6, 2, 4, 6], [5, 1, 7, 6, 6, 2], [5, 7, 7, 6, 5, 1]],
|
||||
[[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1]],
|
||||
[[7, 3, 5, 1, 7, 6], [5, 1, 5, 4, 7, 6], [2, 0, 4, 0, 1, 0]],
|
||||
[[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4]],
|
||||
[[0, 2, 0, 4, 1, 3], [0, 4, 6, 7, 1, 3], [1, 3, 6, 7, 3, 7], [0, 4, 4, 5, 6, 7]],
|
||||
[[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1], [0, 2, 3, 2, 6, 2]],
|
||||
[[1, 5, 5, 4, 7, 6], [1, 5, 7, 6, 3, 7], [1, 0, 3, 2, 2, 6], [2, 6, 0, 4, 1, 0]],
|
||||
[[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4], [2, 0, 3, 2, 2, 6]],
|
||||
[
|
||||
[2, 3, 6, 2, 4, 0],
|
||||
[4, 0, 1, 3, 2, 3],
|
||||
[4, 5, 1, 3, 4, 0],
|
||||
[4, 5, 7, 3, 1, 3],
|
||||
[4, 5, 6, 7, 7, 3],
|
||||
],
|
||||
[[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6]],
|
||||
[[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6], [0, 4, 1, 0, 0, 2]],
|
||||
[[1, 0, 5, 4, 7, 6], [1, 0, 7, 6, 3, 2]],
|
||||
[[2, 3, 0, 2, 0, 4], [0, 4, 4, 5, 2, 3], [2, 3, 4, 5, 6, 7]],
|
||||
[[1, 3, 1, 5, 0, 2], [1, 5, 7, 6, 0, 2], [1, 5, 5, 4, 7, 6], [0, 2, 7, 6, 2, 6]],
|
||||
[
|
||||
[5, 1, 4, 5, 6, 7],
|
||||
[6, 7, 3, 1, 5, 1],
|
||||
[6, 2, 3, 1, 6, 7],
|
||||
[6, 2, 0, 1, 3, 1],
|
||||
[6, 2, 4, 0, 0, 1],
|
||||
],
|
||||
[[6, 7, 2, 6, 2, 0], [2, 0, 0, 1, 6, 7], [6, 7, 0, 1, 4, 5]],
|
||||
[[6, 2, 4, 0, 4, 5], [6, 7, 6, 2, 4, 5]],
|
||||
[[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1]],
|
||||
[[1, 5, 1, 0, 3, 7], [1, 0, 4, 6, 3, 7], [1, 0, 0, 2, 4, 6], [3, 7, 4, 6, 7, 6]],
|
||||
[[1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6]],
|
||||
[[6, 4, 7, 6, 7, 3], [7, 3, 3, 1, 6, 4], [6, 4, 3, 1, 2, 0]],
|
||||
[[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1], [2, 3, 6, 2, 2, 0]],
|
||||
[
|
||||
[7, 6, 3, 7, 1, 5],
|
||||
[1, 5, 4, 6, 7, 6],
|
||||
[1, 0, 4, 6, 1, 5],
|
||||
[1, 0, 2, 6, 4, 6],
|
||||
[1, 0, 3, 2, 2, 6],
|
||||
],
|
||||
[
|
||||
[1, 0, 3, 7, 1, 3],
|
||||
[1, 0, 7, 6, 3, 7],
|
||||
[1, 0, 0, 4, 7, 6],
|
||||
[0, 4, 4, 6, 7, 6],
|
||||
[2, 6, 0, 2, 3, 2],
|
||||
],
|
||||
[[3, 1, 7, 6, 3, 7], [3, 1, 6, 4, 7, 6], [3, 1, 2, 6, 6, 4], [3, 2, 2, 6, 3, 1]],
|
||||
[[3, 2, 3, 1, 7, 6], [3, 1, 0, 4, 7, 6], [7, 6, 0, 4, 6, 4], [3, 1, 1, 5, 0, 4]],
|
||||
[
|
||||
[0, 1, 2, 0, 6, 4],
|
||||
[6, 4, 5, 1, 0, 1],
|
||||
[6, 7, 5, 1, 6, 4],
|
||||
[6, 7, 3, 1, 5, 1],
|
||||
[6, 7, 2, 3, 3, 1],
|
||||
],
|
||||
[[0, 1, 4, 0, 4, 6], [4, 6, 6, 7, 0, 1], [0, 1, 6, 7, 2, 3]],
|
||||
[[6, 7, 2, 3, 2, 0], [6, 4, 6, 7, 2, 0]],
|
||||
[
|
||||
[2, 6, 0, 2, 1, 3],
|
||||
[1, 3, 7, 6, 2, 6],
|
||||
[1, 5, 7, 6, 1, 3],
|
||||
[1, 5, 4, 6, 7, 6],
|
||||
[1, 5, 0, 4, 4, 6],
|
||||
],
|
||||
[[1, 5, 1, 0, 1, 3], [4, 6, 7, 6, 2, 6]],
|
||||
[[0, 1, 2, 6, 0, 2], [0, 1, 6, 7, 2, 6], [0, 1, 4, 6, 6, 7], [0, 4, 4, 6, 0, 1]],
|
||||
[[6, 7, 6, 2, 6, 4]],
|
||||
[[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4]],
|
||||
[[7, 5, 6, 4, 7, 3], [6, 4, 6, 2, 7, 3], [1, 0, 2, 0, 4, 0]],
|
||||
[[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4], [0, 1, 5, 1, 3, 1]],
|
||||
[[2, 0, 0, 4, 1, 5], [2, 0, 1, 5, 3, 1], [2, 6, 3, 7, 7, 5], [7, 5, 6, 4, 2, 6]],
|
||||
[[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4]],
|
||||
[[3, 2, 3, 7, 1, 0], [3, 7, 6, 4, 1, 0], [3, 7, 7, 5, 6, 4], [1, 0, 6, 4, 0, 4]],
|
||||
[[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4], [1, 5, 3, 1, 1, 0]],
|
||||
[
|
||||
[7, 3, 5, 7, 4, 6],
|
||||
[4, 6, 2, 3, 7, 3],
|
||||
[4, 0, 2, 3, 4, 6],
|
||||
[4, 0, 1, 3, 2, 3],
|
||||
[4, 0, 5, 1, 1, 3],
|
||||
],
|
||||
[[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5]],
|
||||
[[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5], [0, 1, 2, 0, 0, 4]],
|
||||
[[1, 0, 1, 5, 3, 2], [1, 5, 4, 6, 3, 2], [3, 2, 4, 6, 2, 6], [1, 5, 5, 7, 4, 6]],
|
||||
[
|
||||
[0, 2, 4, 0, 5, 1],
|
||||
[5, 1, 3, 2, 0, 2],
|
||||
[5, 7, 3, 2, 5, 1],
|
||||
[5, 7, 6, 2, 3, 2],
|
||||
[5, 7, 4, 6, 6, 2],
|
||||
],
|
||||
[[2, 0, 3, 1, 7, 5], [2, 0, 7, 5, 6, 4]],
|
||||
[[4, 6, 0, 4, 0, 1], [0, 1, 1, 3, 4, 6], [4, 6, 1, 3, 5, 7]],
|
||||
[[0, 2, 1, 0, 1, 5], [1, 5, 5, 7, 0, 2], [0, 2, 5, 7, 4, 6]],
|
||||
[[5, 7, 4, 6, 4, 0], [5, 1, 5, 7, 4, 0]],
|
||||
[[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2]],
|
||||
[[0, 1, 0, 2, 4, 5], [0, 2, 3, 7, 4, 5], [4, 5, 3, 7, 5, 7], [0, 2, 2, 6, 3, 7]],
|
||||
[[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2], [1, 0, 5, 1, 1, 3]],
|
||||
[
|
||||
[1, 5, 3, 1, 2, 0],
|
||||
[2, 0, 4, 5, 1, 5],
|
||||
[2, 6, 4, 5, 2, 0],
|
||||
[2, 6, 7, 5, 4, 5],
|
||||
[2, 6, 3, 7, 7, 5],
|
||||
],
|
||||
[[2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5]],
|
||||
[[3, 2, 7, 3, 7, 5], [7, 5, 5, 4, 3, 2], [3, 2, 5, 4, 1, 0]],
|
||||
[
|
||||
[2, 3, 0, 4, 2, 0],
|
||||
[2, 3, 4, 5, 0, 4],
|
||||
[2, 3, 3, 7, 4, 5],
|
||||
[3, 7, 7, 5, 4, 5],
|
||||
[1, 5, 3, 1, 0, 1],
|
||||
],
|
||||
[[3, 2, 1, 5, 3, 1], [3, 2, 5, 4, 1, 5], [3, 2, 7, 5, 5, 4], [3, 7, 7, 5, 3, 2]],
|
||||
[[2, 6, 2, 3, 0, 4], [2, 3, 7, 5, 0, 4], [2, 3, 3, 1, 7, 5], [0, 4, 7, 5, 4, 5]],
|
||||
[
|
||||
[3, 2, 1, 3, 5, 7],
|
||||
[5, 7, 6, 2, 3, 2],
|
||||
[5, 4, 6, 2, 5, 7],
|
||||
[5, 4, 0, 2, 6, 2],
|
||||
[5, 4, 1, 0, 0, 2],
|
||||
],
|
||||
[
|
||||
[4, 5, 0, 4, 2, 6],
|
||||
[2, 6, 7, 5, 4, 5],
|
||||
[2, 3, 7, 5, 2, 6],
|
||||
[2, 3, 1, 5, 7, 5],
|
||||
[2, 3, 0, 1, 1, 5],
|
||||
],
|
||||
[[2, 3, 2, 0, 2, 6], [1, 5, 7, 5, 4, 5]],
|
||||
[[5, 7, 4, 5, 4, 0], [4, 0, 0, 2, 5, 7], [5, 7, 0, 2, 1, 3]],
|
||||
[[5, 4, 1, 0, 1, 3], [5, 7, 5, 4, 1, 3]],
|
||||
[[0, 2, 4, 5, 0, 4], [0, 2, 5, 7, 4, 5], [0, 2, 1, 5, 5, 7], [0, 1, 1, 5, 0, 2]],
|
||||
[[5, 4, 5, 1, 5, 7]],
|
||||
[[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3]],
|
||||
[[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3], [0, 2, 4, 0, 0, 1]],
|
||||
[[3, 7, 3, 1, 2, 6], [3, 1, 5, 4, 2, 6], [3, 1, 1, 0, 5, 4], [2, 6, 5, 4, 6, 4]],
|
||||
[
|
||||
[6, 4, 2, 6, 3, 7],
|
||||
[3, 7, 5, 4, 6, 4],
|
||||
[3, 1, 5, 4, 3, 7],
|
||||
[3, 1, 0, 4, 5, 4],
|
||||
[3, 1, 2, 0, 0, 4],
|
||||
],
|
||||
[[2, 0, 2, 3, 6, 4], [2, 3, 1, 5, 6, 4], [6, 4, 1, 5, 4, 5], [2, 3, 3, 7, 1, 5]],
|
||||
[
|
||||
[0, 4, 1, 0, 3, 2],
|
||||
[3, 2, 6, 4, 0, 4],
|
||||
[3, 7, 6, 4, 3, 2],
|
||||
[3, 7, 5, 4, 6, 4],
|
||||
[3, 7, 1, 5, 5, 4],
|
||||
],
|
||||
[
|
||||
[1, 3, 0, 1, 4, 5],
|
||||
[4, 5, 7, 3, 1, 3],
|
||||
[4, 6, 7, 3, 4, 5],
|
||||
[4, 6, 2, 3, 7, 3],
|
||||
[4, 6, 0, 2, 2, 3],
|
||||
],
|
||||
[[3, 7, 3, 1, 3, 2], [5, 4, 6, 4, 0, 4]],
|
||||
[[3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4]],
|
||||
[
|
||||
[3, 1, 2, 6, 3, 2],
|
||||
[3, 1, 6, 4, 2, 6],
|
||||
[3, 1, 1, 5, 6, 4],
|
||||
[1, 5, 5, 4, 6, 4],
|
||||
[0, 4, 1, 0, 2, 0],
|
||||
],
|
||||
[[4, 5, 6, 4, 6, 2], [6, 2, 2, 3, 4, 5], [4, 5, 2, 3, 0, 1]],
|
||||
[[2, 3, 6, 4, 2, 6], [2, 3, 4, 5, 6, 4], [2, 3, 0, 4, 4, 5], [2, 0, 0, 4, 2, 3]],
|
||||
[[1, 3, 5, 1, 5, 4], [5, 4, 4, 6, 1, 3], [1, 3, 4, 6, 0, 2]],
|
||||
[[1, 3, 0, 4, 1, 0], [1, 3, 4, 6, 0, 4], [1, 3, 5, 4, 4, 6], [1, 5, 5, 4, 1, 3]],
|
||||
[[4, 6, 0, 2, 0, 1], [4, 5, 4, 6, 0, 1]],
|
||||
[[4, 6, 4, 0, 4, 5]],
|
||||
[[4, 0, 6, 2, 7, 3], [4, 0, 7, 3, 5, 1]],
|
||||
[[1, 5, 0, 1, 0, 2], [0, 2, 2, 6, 1, 5], [1, 5, 2, 6, 3, 7]],
|
||||
[[3, 7, 1, 3, 1, 0], [1, 0, 0, 4, 3, 7], [3, 7, 0, 4, 2, 6]],
|
||||
[[3, 1, 2, 0, 2, 6], [3, 7, 3, 1, 2, 6]],
|
||||
[[0, 4, 2, 0, 2, 3], [2, 3, 3, 7, 0, 4], [0, 4, 3, 7, 1, 5]],
|
||||
[[3, 7, 1, 5, 1, 0], [3, 2, 3, 7, 1, 0]],
|
||||
[[0, 4, 1, 3, 0, 1], [0, 4, 3, 7, 1, 3], [0, 4, 2, 3, 3, 7], [0, 2, 2, 3, 0, 4]],
|
||||
[[3, 7, 3, 1, 3, 2]],
|
||||
[[2, 6, 3, 2, 3, 1], [3, 1, 1, 5, 2, 6], [2, 6, 1, 5, 0, 4]],
|
||||
[[1, 5, 3, 2, 1, 3], [1, 5, 2, 6, 3, 2], [1, 5, 0, 2, 2, 6], [1, 0, 0, 2, 1, 5]],
|
||||
[[2, 3, 0, 1, 0, 4], [2, 6, 2, 3, 0, 4]],
|
||||
[[2, 3, 2, 0, 2, 6]],
|
||||
[[1, 5, 0, 4, 0, 2], [1, 3, 1, 5, 0, 2]],
|
||||
[[1, 5, 1, 0, 1, 3]],
|
||||
[[0, 2, 0, 1, 0, 4]],
|
||||
[],
|
||||
]
|
||||
|
||||
|
||||
def create_mc_lookup_table():
|
||||
cases = torch.zeros(256, 5, 3, dtype=torch.long)
|
||||
masks = torch.zeros(256, 5, dtype=torch.bool)
|
||||
|
||||
edge_to_index = {
|
||||
(0, 1): 0,
|
||||
(2, 3): 1,
|
||||
(4, 5): 2,
|
||||
(6, 7): 3,
|
||||
(0, 2): 4,
|
||||
(1, 3): 5,
|
||||
(4, 6): 6,
|
||||
(5, 7): 7,
|
||||
(0, 4): 8,
|
||||
(1, 5): 9,
|
||||
(2, 6): 10,
|
||||
(3, 7): 11,
|
||||
}
|
||||
|
||||
for i, case in enumerate(MC_TABLE):
|
||||
for j, tri in enumerate(case):
|
||||
for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])):
|
||||
cases[i, j, k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)]
|
||||
masks[i, j] = True
|
||||
return cases, masks
|
||||
|
||||
|
||||
RENDERER_CONFIG = {}
|
||||
|
||||
|
||||
@@ -400,7 +881,12 @@ def renderer_model_original_checkpoint_to_diffusers_checkpoint(model, checkpoint
|
||||
}
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update({"void.background": torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)})
|
||||
diffusers_checkpoint.update({"void.background": model.state_dict()["void.background"]})
|
||||
|
||||
cases, masks = create_mc_lookup_table()
|
||||
|
||||
diffusers_checkpoint.update({"mesh_decoder.cases": cases})
|
||||
diffusers_checkpoint.update({"mesh_decoder.masks": masks})
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .models import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
ModelMixin,
|
||||
@@ -199,6 +200,7 @@ except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import (
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
|
||||
+358
-5
@@ -14,9 +14,12 @@
|
||||
import os
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -42,10 +45,13 @@ from .utils import (
|
||||
HF_HUB_OFFLINE,
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_omegaconf_available,
|
||||
is_safetensors_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from .utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
@@ -54,6 +60,9 @@ if is_safetensors_available():
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -1319,8 +1328,8 @@ class FromSingleFileMixin:
|
||||
@classmethod
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
|
||||
is set in evaluation mode (`model.eval()`) by default.
|
||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
|
||||
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
@@ -1430,6 +1439,7 @@ class FromSingleFileMixin:
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||
prediction_type = kwargs.pop("prediction_type", None)
|
||||
text_encoder = kwargs.pop("text_encoder", None)
|
||||
controlnet = kwargs.pop("controlnet", None)
|
||||
tokenizer = kwargs.pop("tokenizer", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
@@ -1446,11 +1456,18 @@ class FromSingleFileMixin:
|
||||
# TODO: For now we only support stable diffusion
|
||||
stable_unclip = None
|
||||
model_type = None
|
||||
controlnet = False
|
||||
|
||||
if pipeline_name == "StableDiffusionControlNetPipeline":
|
||||
if pipeline_name in [
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
]:
|
||||
from .models.controlnet import ControlNetModel
|
||||
from .pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
|
||||
# Model type will be inferred from the checkpoint.
|
||||
controlnet = True
|
||||
if not isinstance(controlnet, (ControlNetModel, MultiControlNetModel)):
|
||||
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
|
||||
elif "StableDiffusion" in pipeline_name:
|
||||
# Model type will be inferred from the checkpoint.
|
||||
pass
|
||||
@@ -1519,3 +1536,339 @@ class FromSingleFileMixin:
|
||||
pipe.to(torch_dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
@classmethod
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained controlnet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is format. The pipeline is set in evaluation mode (`model.eval()`) by
|
||||
default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
|
||||
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
|
||||
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you want to load
|
||||
a VAE that does accompany a stable diffusion model of v2 or higher or SDXL.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
if not is_omegaconf_available():
|
||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from .models import AutoencoderKL
|
||||
|
||||
# import here to avoid circular dependency
|
||||
from .pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
)
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scaling_factor = kwargs.pop("scaling_factor", None)
|
||||
kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
checkpoint[key] = f.get_tensor(key)
|
||||
else:
|
||||
checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = OmegaConf.load(config_file)
|
||||
|
||||
# default to sd-v1-5
|
||||
image_size = image_size or 512
|
||||
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
if scaling_factor is None:
|
||||
if (
|
||||
"model" in original_config
|
||||
and "params" in original_config.model
|
||||
and "scale_factor" in original_config.model.params
|
||||
):
|
||||
vae_scaling_factor = original_config.model.params.scale_factor
|
||||
else:
|
||||
vae_scaling_factor = 0.18215 # default SD scaling factor
|
||||
|
||||
vae_config["scaling_factor"] = vae_scaling_factor
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
for param_name, param in converted_vae_checkpoint.items():
|
||||
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
|
||||
else:
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
vae.to(torch_dtype=torch_dtype)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
class FromOriginalControlnetMixin:
|
||||
@classmethod
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained controlnet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlnetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlnetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from .pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
use_linear_projection = kwargs.pop("use_linear_projection", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
image_size = image_size or 512
|
||||
|
||||
controlnet = download_controlnet_from_original_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
original_config_file=config_file,
|
||||
image_size=image_size,
|
||||
extract_ema=extract_ema,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
from_safetensors=from_safetensors,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
controlnet.to(torch_dtype=torch_dtype)
|
||||
|
||||
return controlnet
|
||||
|
||||
@@ -17,6 +17,7 @@ from ..utils import is_flax_available, is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .controlnet import ControlNetModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
|
||||
@@ -1096,7 +1096,6 @@ class AttnProcessor2_0:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
inner_dim = hidden_states.shape[-1]
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
@@ -1117,6 +1116,7 @@ class AttnProcessor2_0:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import apply_forward_hook
|
||||
from .autoencoder_kl import AutoencoderKLOutput
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
|
||||
|
||||
|
||||
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss
|
||||
for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of down block output channels.
|
||||
layers_per_down_block (`int`, *optional*, defaults to `1`):
|
||||
Number layers for down block.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of up block output channels.
|
||||
layers_per_up_block (`int`, *optional*, defaults to `1`):
|
||||
Number layers for up block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`):
|
||||
Number of groups to use for the first normalization layer in ResNet blocks.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
||||
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
||||
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
down_block_out_channels: Tuple[int] = (64,),
|
||||
layers_per_down_block: int = 1,
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||
up_block_out_channels: Tuple[int] = (64,),
|
||||
layers_per_up_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
norm_num_groups: int = 32,
|
||||
sample_size: int = 32,
|
||||
scaling_factor: float = 0.18215,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=down_block_out_channels,
|
||||
layers_per_block=layers_per_down_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
double_z=True,
|
||||
)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = MaskConditionDecoder(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=up_block_out_channels,
|
||||
layers_per_block=layers_per_up_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
)
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(
|
||||
self,
|
||||
z: torch.FloatTensor,
|
||||
image: Optional[torch.FloatTensor] = None,
|
||||
mask: Optional[torch.FloatTensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z, image, mask)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self,
|
||||
z: torch.FloatTensor,
|
||||
image: Optional[torch.FloatTensor] = None,
|
||||
mask: Optional[torch.FloatTensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
decoded = self._decode(z, image, mask).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
mask: Optional[torch.FloatTensor] = None,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): Input sample.
|
||||
mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, sample, mask).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
@@ -18,6 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalVAEMixin
|
||||
from ..utils import BaseOutput, apply_forward_hook
|
||||
from .attention_processor import AttentionProcessor, AttnProcessor
|
||||
from .modeling_utils import ModelMixin
|
||||
@@ -38,7 +39,7 @@ class AutoencoderKLOutput(BaseOutput):
|
||||
latent_dist: "DiagonalGaussianDistribution"
|
||||
|
||||
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
|
||||
@@ -19,9 +19,10 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalControlnetMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import AttentionProcessor, AttnProcessor
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
@@ -100,7 +101,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
@@ -131,12 +132,25 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
@@ -177,10 +191,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
@@ -188,6 +207,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads=64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -215,6 +235,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
@@ -224,16 +247,43 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = "text_proj"
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type == "text_proj":
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||
elif encoder_hid_dim_type == "text_image_proj":
|
||||
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
||||
self.encoder_hid_proj = TextImageProjection(
|
||||
text_embed_dim=encoder_hid_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
@@ -257,6 +307,29 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == "text":
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
)
|
||||
elif addition_embed_type == "text_image":
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
|
||||
# control net conditioning embedding
|
||||
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
@@ -291,6 +364,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
@@ -327,6 +401,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
self.controlnet_mid_block = controlnet_block
|
||||
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
in_channels=mid_block_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
@@ -356,7 +431,22 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
)
|
||||
|
||||
controlnet = cls(
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
in_channels=unet.config.in_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
@@ -542,6 +632,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
@@ -564,7 +655,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
@@ -618,6 +711,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
@@ -629,6 +723,30 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if "addition_embed_type" in self.config:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
|
||||
@@ -235,12 +235,12 @@ class OutConv1DBlock(nn.Module):
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim, embed_dim):
|
||||
def __init__(self, fc_dim, embed_dim, act_fn="mish"):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
nn.Mish(),
|
||||
get_activation(act_fn),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
@@ -652,5 +652,5 @@ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, ac
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim)
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
|
||||
return None
|
||||
|
||||
@@ -280,6 +280,253 @@ class Decoder(nn.Module):
|
||||
return sample
|
||||
|
||||
|
||||
class UpSample(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
|
||||
|
||||
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
||||
x = torch.relu(x)
|
||||
x = self.deconv(x)
|
||||
return x
|
||||
|
||||
|
||||
class MaskConditionEncoder(nn.Module):
|
||||
"""
|
||||
used in AsymmetricAutoencoderKL
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_ch: int,
|
||||
out_ch: int = 192,
|
||||
res_ch: int = 768,
|
||||
stride: int = 16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
channels = []
|
||||
while stride > 1:
|
||||
stride = stride // 2
|
||||
in_ch_ = out_ch * 2
|
||||
if out_ch > res_ch:
|
||||
out_ch = res_ch
|
||||
if stride == 1:
|
||||
in_ch_ = res_ch
|
||||
channels.append((in_ch_, out_ch))
|
||||
out_ch *= 2
|
||||
|
||||
out_channels = []
|
||||
for _in_ch, _out_ch in channels:
|
||||
out_channels.append(_out_ch)
|
||||
out_channels.append(channels[-1][0])
|
||||
|
||||
layers = []
|
||||
in_ch_ = in_ch
|
||||
for l in range(len(out_channels)):
|
||||
out_ch_ = out_channels[l]
|
||||
if l == 0 or l == 1:
|
||||
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
|
||||
else:
|
||||
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
|
||||
in_ch_ = out_ch_
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
|
||||
out = {}
|
||||
for l in range(len(self.layers)):
|
||||
layer = self.layers[l]
|
||||
x = layer(x)
|
||||
out[str(tuple(x.shape))] = x
|
||||
x = torch.relu(x)
|
||||
return out
|
||||
|
||||
|
||||
class MaskConditionDecoder(nn.Module):
|
||||
"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
|
||||
decoder with a conditioner on the mask and masked image."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(64,),
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
act_fn="silu",
|
||||
norm_type="group", # group, spatial
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[-1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
temb_channels = in_channels if norm_type == "spatial" else None
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=1,
|
||||
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=temb_channels,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=None,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resnet_time_scale_shift=norm_type,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# condition encoder
|
||||
self.condition_encoder = MaskConditionEncoder(
|
||||
in_ch=out_channels,
|
||||
out_ch=block_out_channels[0],
|
||||
res_ch=block_out_channels[-1],
|
||||
)
|
||||
|
||||
# out
|
||||
if norm_type == "spatial":
|
||||
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
||||
else:
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, z, image=None, mask=None, latent_embeds=None):
|
||||
sample = z
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# condition encoder
|
||||
if image is not None and mask is not None:
|
||||
masked_image = (1 - mask) * image
|
||||
im_x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
|
||||
)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
if image is not None and mask is not None:
|
||||
sample_ = im_x[str(tuple(sample.shape))]
|
||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
|
||||
)
|
||||
if image is not None and mask is not None:
|
||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||
else:
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), sample, latent_embeds
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# condition encoder
|
||||
if image is not None and mask is not None:
|
||||
masked_image = (1 - mask) * image
|
||||
im_x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.condition_encoder), masked_image, mask
|
||||
)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
if image is not None and mask is not None:
|
||||
sample_ = im_x[str(tuple(sample.shape))]
|
||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
||||
if image is not None and mask is not None:
|
||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# condition encoder
|
||||
if image is not None and mask is not None:
|
||||
masked_image = (1 - mask) * image
|
||||
im_x = self.condition_encoder(masked_image, mask)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
if image is not None and mask is not None:
|
||||
sample_ = im_x[str(tuple(sample.shape))]
|
||||
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
||||
sample = sample * mask_ + sample_ * (1 - mask_)
|
||||
sample = up_block(sample, latent_embeds)
|
||||
if image is not None and mask is not None:
|
||||
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
||||
|
||||
# post-process
|
||||
if latent_embeds is None:
|
||||
sample = self.conv_norm_out(sample)
|
||||
else:
|
||||
sample = self.conv_norm_out(sample, latent_embeds)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class VectorQuantizer(nn.Module):
|
||||
"""
|
||||
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
|
||||
|
||||
@@ -120,6 +120,7 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
||||
else:
|
||||
from .controlnet import StableDiffusionXLControlNetPipeline
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_flax_available,
|
||||
is_invisible_watermark_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -45,17 +45,17 @@ class MultiControlNetModel(ModelMixin):
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
down_samples, mid_sample = controlnet(
|
||||
sample,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
image,
|
||||
scale,
|
||||
class_labels,
|
||||
timestep_cond,
|
||||
attention_mask,
|
||||
cross_attention_kwargs,
|
||||
guess_mode,
|
||||
return_dict,
|
||||
sample=sample,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=scale,
|
||||
class_labels=class_labels,
|
||||
timestep_cond=timestep_cond,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -90,7 +90,9 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionControlNetPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
||||
|
||||
@@ -912,7 +914,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -116,7 +116,9 @@ def prepare_image(image):
|
||||
return image
|
||||
|
||||
|
||||
class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionControlNetImg2ImgPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
||||
|
||||
@@ -1005,7 +1007,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
@@ -25,7 +25,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -222,7 +222,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionControlNetInpaintPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
||||
|
||||
@@ -1240,7 +1242,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
@@ -0,0 +1,960 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_compiled_module,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # To be updated when there's a useful ControlNet checkpoint
|
||||
>>> # compatible with SDXL.
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
text_encoder_2 ([` CLIPTextModelWithProjection`]):
|
||||
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the
|
||||
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
||||
variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Second Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
|
||||
as a list, the outputs from each ControlNet are added together to create one combined additional
|
||||
conditioning.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: ControlNetModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
raise ValueError("MultiControlNet is not yet supported.")
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
)
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
||||
steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
|
||||
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# control net hook has be manually offloaded as it alternates with unet
|
||||
cpu_offload_with_hook(self.controlnet, device)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
||||
text_encoders = (
|
||||
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
||||
)
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
||||
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
uncond_tokens: List[str]
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
negative_prompt_embeds = text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
bs_embed = pooled_prompt_embeds.shape[0]
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Check `image`
|
||||
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
||||
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
||||
)
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
||||
):
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetModel)
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
else:
|
||||
assert False
|
||||
|
||||
if len(control_guidance_start) != len(control_guidance_end):
|
||||
raise ValueError(
|
||||
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
|
||||
)
|
||||
|
||||
for start, end in zip(control_guidance_start, control_guidance_end):
|
||||
if start >= end:
|
||||
raise ValueError(
|
||||
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
|
||||
)
|
||||
if start < 0.0:
|
||||
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
|
||||
if end > 1.0:
|
||||
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
||||
|
||||
def check_image(self, image, prompt, prompt_embeds):
|
||||
image_is_pil = isinstance(image, PIL.Image.Image)
|
||||
image_is_tensor = isinstance(image, torch.Tensor)
|
||||
image_is_np = isinstance(image, np.ndarray)
|
||||
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
||||
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
||||
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
|
||||
|
||||
if (
|
||||
not image_is_pil
|
||||
and not image_is_tensor
|
||||
and not image_is_np
|
||||
and not image_is_pil_list
|
||||
and not image_is_tensor_list
|
||||
and not image_is_np_list
|
||||
):
|
||||
raise TypeError(
|
||||
f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
|
||||
)
|
||||
|
||||
if image_is_pil:
|
||||
image_batch_size = 1
|
||||
else:
|
||||
image_batch_size = len(image)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
prompt_batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
prompt_batch_size = len(prompt)
|
||||
elif prompt_embeds is not None:
|
||||
prompt_batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
||||
raise ValueError(
|
||||
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
||||
)
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||
)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
original_size: Tuple[int, int] = (1024, 1024),
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Tuple[int, int] = (1024, 1024),
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
|
||||
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
|
||||
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
|
||||
specified in init, images must be passed as a list such that each element of the list can be correctly
|
||||
batched for input to a single controlnet.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
|
||||
corresponding scale as a list.
|
||||
guess_mode (`bool`, *optional*, defaults to `False`):
|
||||
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
|
||||
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
||||
The percentage of total steps at which the controlnet starts applying.
|
||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the controlnet stops applying.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
TODO
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
TODO
|
||||
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
TODO
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
|
||||
containing the output images.
|
||||
"""
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
||||
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
|
||||
control_guidance_end
|
||||
]
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
global_pool_conditions = (
|
||||
controlnet.config.global_pool_conditions
|
||||
if isinstance(controlnet, ControlNetModel)
|
||||
else controlnet.nets[0].config.global_pool_conditions
|
||||
)
|
||||
guess_mode = guess_mode or global_pool_conditions
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
image = self.prepare_image(
|
||||
image=image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
height, width = image.shape[-2:]
|
||||
else:
|
||||
assert False
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7.1 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
|
||||
|
||||
# 7.2 Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
||||
else:
|
||||
control_model_input = latent_model_input
|
||||
controlnet_prompt_embeds = prompt_embeds
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
else:
|
||||
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
control_model_input,
|
||||
t,
|
||||
encoder_hidden_states=controlnet_prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=cond_scale,
|
||||
guess_mode=guess_mode,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infered ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
||||
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
image = self.watermark.apply_watermark(image)
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
@@ -95,7 +95,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`HeunDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
renderer ([`ShapERenderer`]):
|
||||
shap_e_renderer ([`ShapERenderer`]):
|
||||
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
|
||||
with the NeRF rendering method
|
||||
"""
|
||||
@@ -106,7 +106,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
scheduler: HeunDiscreteScheduler,
|
||||
renderer: ShapERenderer,
|
||||
shap_e_renderer: ShapERenderer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -115,7 +115,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
renderer=renderer,
|
||||
shap_e_renderer=shap_e_renderer,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
@@ -149,7 +149,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
|
||||
for cpu_offloaded_model in [self.text_encoder, self.prior, self.shap_e_renderer]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
@@ -218,7 +218,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
frame_size: int = 64,
|
||||
output_type: Optional[str] = "pil", # pil, np, latent
|
||||
output_type: Optional[str] = "pil", # pil, np, latent, mesh
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -248,8 +248,8 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
frame_size (`int`, *optional*, default to 64):
|
||||
the width and height of each image frame of the generated 3d output
|
||||
output_type (`str`, *optional*, defaults to `"pt"`):
|
||||
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
|
||||
(`torch.Tensor`).
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -319,30 +319,39 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
sample=latents,
|
||||
).prev_sample
|
||||
|
||||
if output_type not in ["np", "pil", "latent", "mesh"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
|
||||
)
|
||||
|
||||
if output_type == "latent":
|
||||
return ShapEPipelineOutput(images=latents)
|
||||
|
||||
images = []
|
||||
for i, latent in enumerate(latents):
|
||||
image = self.renderer.decode(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
ray_batch_size=4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
)
|
||||
images.append(image)
|
||||
if output_type == "mesh":
|
||||
for i, latent in enumerate(latents):
|
||||
mesh = self.shap_e_renderer.decode_to_mesh(
|
||||
latent[None, :],
|
||||
device,
|
||||
)
|
||||
images.append(mesh)
|
||||
|
||||
images = torch.stack(images)
|
||||
else:
|
||||
# np, pil
|
||||
for i, latent in enumerate(latents):
|
||||
image = self.shap_e_renderer.decode_to_image(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
)
|
||||
images.append(image)
|
||||
|
||||
if output_type not in ["np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
|
||||
images = torch.stack(images)
|
||||
|
||||
images = images.cpu().numpy()
|
||||
images = images.cpu().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
|
||||
@@ -94,7 +94,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`HeunDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
renderer ([`ShapERenderer`]):
|
||||
shap_e_renderer ([`ShapERenderer`]):
|
||||
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
|
||||
with the NeRF rendering method
|
||||
"""
|
||||
@@ -105,7 +105,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
image_encoder: CLIPVisionModel,
|
||||
image_processor: CLIPImageProcessor,
|
||||
scheduler: HeunDiscreteScheduler,
|
||||
renderer: ShapERenderer,
|
||||
shap_e_renderer: ShapERenderer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -114,7 +114,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
scheduler=scheduler,
|
||||
renderer=renderer,
|
||||
shap_e_renderer=shap_e_renderer,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
@@ -170,7 +170,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
frame_size: int = 64,
|
||||
output_type: Optional[str] = "pil", # pil, np, latent
|
||||
output_type: Optional[str] = "pil", # pil, np, latent, mesh
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -200,8 +200,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
frame_size (`int`, *optional*, default to 64):
|
||||
the width and height of each image frame of the generated 3d output
|
||||
output_type (`str`, *optional*, defaults to `"pt"`):
|
||||
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
|
||||
(`torch.Tensor`).
|
||||
(`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -275,32 +274,39 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
sample=latents,
|
||||
).prev_sample
|
||||
|
||||
if output_type not in ["np", "pil", "latent", "mesh"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
|
||||
)
|
||||
|
||||
if output_type == "latent":
|
||||
return ShapEPipelineOutput(images=latents)
|
||||
|
||||
images = []
|
||||
for i, latent in enumerate(latents):
|
||||
print()
|
||||
image = self.renderer.decode(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
ray_batch_size=4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
)
|
||||
if output_type == "mesh":
|
||||
for i, latent in enumerate(latents):
|
||||
mesh = self.shap_e_renderer.decode_to_mesh(
|
||||
latent[None, :],
|
||||
device,
|
||||
)
|
||||
images.append(mesh)
|
||||
|
||||
images.append(image)
|
||||
else:
|
||||
# np, pil
|
||||
for i, latent in enumerate(latents):
|
||||
image = self.shap_e_renderer.decode_to_image(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
)
|
||||
images.append(image)
|
||||
|
||||
images = torch.stack(images)
|
||||
images = torch.stack(images)
|
||||
|
||||
if output_type not in ["np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
|
||||
images = images.cpu().numpy()
|
||||
|
||||
images = images.cpu().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -116,6 +116,101 @@ def integrate_samples(volume_range, ts, density, channels):
|
||||
return channels, weights, transmittance
|
||||
|
||||
|
||||
def volume_query_points(volume, grid_size):
|
||||
indices = torch.arange(grid_size**3, device=volume.bbox_min.device)
|
||||
zs = indices % grid_size
|
||||
ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size
|
||||
xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size
|
||||
combined = torch.stack([xs, ys, zs], dim=1)
|
||||
return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min
|
||||
|
||||
|
||||
def _convert_srgb_to_linear(u: torch.Tensor):
|
||||
return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4)
|
||||
|
||||
|
||||
def _create_flat_edge_indices(
|
||||
flat_cube_indices: torch.Tensor,
|
||||
grid_size: Tuple[int, int, int],
|
||||
):
|
||||
num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2]
|
||||
y_offset = num_xs
|
||||
num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2]
|
||||
z_offset = num_xs + num_ys
|
||||
return torch.stack(
|
||||
[
|
||||
# Edges spanning x-axis.
|
||||
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2],
|
||||
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||||
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 2],
|
||||
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
+ 1,
|
||||
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||||
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
+ 1,
|
||||
# Edges spanning y-axis.
|
||||
(
|
||||
y_offset
|
||||
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
y_offset
|
||||
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
y_offset
|
||||
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
+ 1
|
||||
),
|
||||
(
|
||||
y_offset
|
||||
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
+ 1
|
||||
),
|
||||
# Edges spanning z-axis.
|
||||
(
|
||||
z_offset
|
||||
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
z_offset
|
||||
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
z_offset
|
||||
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
|
||||
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
z_offset
|
||||
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
|
||||
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
class VoidNeRFModel(nn.Module):
|
||||
"""
|
||||
Implements the default empty space model where all queries are rendered as background.
|
||||
@@ -368,6 +463,141 @@ class ImportanceRaySampler(nn.Module):
|
||||
return ts
|
||||
|
||||
|
||||
@dataclass
|
||||
class MeshDecoderOutput(BaseOutput):
|
||||
"""
|
||||
A 3D triangle mesh with optional data at the vertices and faces.
|
||||
|
||||
Args:
|
||||
verts (`torch.Tensor` of shape `(N, 3)`):
|
||||
array of vertext coordinates
|
||||
faces (`torch.Tensor` of shape `(N, 3)`):
|
||||
array of triangles, pointing to indices in verts.
|
||||
vertext_channels (Dict):
|
||||
vertext coordinates for each color channel
|
||||
"""
|
||||
|
||||
verts: torch.Tensor
|
||||
faces: torch.Tensor
|
||||
vertex_channels: Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
class MeshDecoder(nn.Module):
|
||||
"""
|
||||
Construct meshes from Signed distance functions (SDFs) using marching cubes method
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
cases = torch.zeros(256, 5, 3, dtype=torch.long)
|
||||
masks = torch.zeros(256, 5, dtype=torch.bool)
|
||||
|
||||
self.register_buffer("cases", cases)
|
||||
self.register_buffer("masks", masks)
|
||||
|
||||
def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor):
|
||||
"""
|
||||
For a signed distance field, produce a mesh using marching cubes.
|
||||
|
||||
:param field: a 3D tensor of field values, where negative values correspond
|
||||
to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively.
|
||||
:param min_point: a tensor of shape [3] containing the point corresponding
|
||||
to (0, 0, 0) in the field.
|
||||
:param size: a tensor of shape [3] containing the per-axis distance from the
|
||||
(0, 0, 0) field corner and the (-1, -1, -1) field corner.
|
||||
"""
|
||||
assert len(field.shape) == 3, "input must be a 3D scalar field"
|
||||
dev = field.device
|
||||
|
||||
cases = self.cases.to(dev)
|
||||
masks = self.masks.to(dev)
|
||||
|
||||
min_point = min_point.to(dev)
|
||||
size = size.to(dev)
|
||||
|
||||
grid_size = field.shape
|
||||
grid_size_tensor = torch.tensor(grid_size).to(size)
|
||||
|
||||
# Create bitmasks between 0 and 255 (inclusive) indicating the state
|
||||
# of the eight corners of each cube.
|
||||
bitmasks = (field > 0).to(torch.uint8)
|
||||
bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)
|
||||
bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)
|
||||
bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)
|
||||
|
||||
# Compute corner coordinates across the entire grid.
|
||||
corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)
|
||||
corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[
|
||||
:, None, None
|
||||
]
|
||||
corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[
|
||||
:, None
|
||||
]
|
||||
corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype)
|
||||
|
||||
# Compute all vertices across all edges in the grid, even though we will
|
||||
# throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
|
||||
# These are all midpoints, and don't account for interpolation (which is
|
||||
# done later based on the used edge midpoints).
|
||||
edge_midpoints = torch.cat(
|
||||
[
|
||||
((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
|
||||
((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
|
||||
((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Create a flat array of [X, Y, Z] indices for each cube.
|
||||
cube_indices = torch.zeros(
|
||||
grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long
|
||||
)
|
||||
cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None]
|
||||
cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None]
|
||||
cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev)
|
||||
flat_cube_indices = cube_indices.reshape(-1, 3)
|
||||
|
||||
# Create a flat array mapping each cube to 12 global edge indices.
|
||||
edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)
|
||||
|
||||
# Apply the LUT to figure out the triangles.
|
||||
flat_bitmasks = bitmasks.reshape(-1).long() # must cast to long for indexing to believe this not a mask
|
||||
local_tris = cases[flat_bitmasks]
|
||||
local_masks = masks[flat_bitmasks]
|
||||
# Compute the global edge indices for the triangles.
|
||||
global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape(
|
||||
local_tris.shape
|
||||
)
|
||||
# Select the used triangles for each cube.
|
||||
selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]
|
||||
|
||||
# Now we have a bunch of indices into the full list of possible vertices,
|
||||
# but we want to reduce this list to only the used vertices.
|
||||
used_vertex_indices = torch.unique(selected_tris.view(-1))
|
||||
used_edge_midpoints = edge_midpoints[used_vertex_indices]
|
||||
old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)
|
||||
old_index_to_new_index[used_vertex_indices] = torch.arange(
|
||||
len(used_vertex_indices), device=dev, dtype=torch.long
|
||||
)
|
||||
|
||||
# Rewrite the triangles to use the new indices
|
||||
faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape)
|
||||
|
||||
# Compute the actual interpolated coordinates corresponding to edge midpoints.
|
||||
v1 = torch.floor(used_edge_midpoints).to(torch.long)
|
||||
v2 = torch.ceil(used_edge_midpoints).to(torch.long)
|
||||
s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
|
||||
s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
|
||||
p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
|
||||
p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point
|
||||
# The signs of s1 and s2 should be different. We want to find
|
||||
# t such that t*s2 + (1-t)*s1 = 0.
|
||||
t = (s1 / (s1 - s2))[:, None]
|
||||
verts = t * p2 + (1 - t) * p1
|
||||
|
||||
return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLPNeRFModelOutput(BaseOutput):
|
||||
density: torch.Tensor
|
||||
@@ -429,7 +659,7 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin):
|
||||
|
||||
return mapped_output
|
||||
|
||||
def forward(self, *, position, direction, ts, nerf_level="coarse"):
|
||||
def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"):
|
||||
h = encode_position(position)
|
||||
|
||||
h_preact = h
|
||||
@@ -455,10 +685,17 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin):
|
||||
|
||||
if nerf_level == "coarse":
|
||||
h_density = activation["density_coarse"]
|
||||
h_channels = activation["nerf_coarse"]
|
||||
else:
|
||||
h_density = activation["density_fine"]
|
||||
h_channels = activation["nerf_fine"]
|
||||
|
||||
if rendering_mode == "nerf":
|
||||
if nerf_level == "coarse":
|
||||
h_channels = activation["nerf_coarse"]
|
||||
else:
|
||||
h_channels = activation["nerf_fine"]
|
||||
|
||||
elif rendering_mode == "stf":
|
||||
h_channels = activation["stf"]
|
||||
|
||||
density = self.density_activation(h_density)
|
||||
signed_distance = self.sdf_activation(activation["sdf"])
|
||||
@@ -583,6 +820,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
|
||||
self.void = VoidNeRFModel(background=background, channel_scale=255.0)
|
||||
self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
|
||||
self.mesh_decoder = MeshDecoder()
|
||||
|
||||
@torch.no_grad()
|
||||
def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
|
||||
@@ -664,7 +902,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
return channels, weighted_sampler, model_out
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
def decode_to_image(
|
||||
self,
|
||||
latents,
|
||||
device,
|
||||
@@ -707,3 +945,106 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)
|
||||
|
||||
return images
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_to_mesh(
|
||||
self,
|
||||
latents,
|
||||
device,
|
||||
grid_size: int = 128,
|
||||
query_batch_size: int = 4096,
|
||||
texture_channels: Tuple = ("R", "G", "B"),
|
||||
):
|
||||
# 1. project the the paramters from the generated latents
|
||||
projected_params = self.params_proj(latents)
|
||||
|
||||
# 2. update the mlp layers of the renderer
|
||||
for name, param in self.mlp.state_dict().items():
|
||||
if f"nerstf.{name}" in projected_params.keys():
|
||||
param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))
|
||||
|
||||
# 3. decoding with STF rendering
|
||||
# 3.1 query the SDF values at vertices along a regular 128**3 grid
|
||||
|
||||
query_points = volume_query_points(self.volume, grid_size)
|
||||
query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype)
|
||||
|
||||
fields = []
|
||||
|
||||
for idx in range(0, query_positions.shape[1], query_batch_size):
|
||||
query_batch = query_positions[:, idx : idx + query_batch_size]
|
||||
|
||||
model_out = self.mlp(
|
||||
position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
|
||||
)
|
||||
fields.append(model_out.signed_distance)
|
||||
|
||||
# predicted SDF values
|
||||
fields = torch.cat(fields, dim=1)
|
||||
fields = fields.float()
|
||||
|
||||
assert (
|
||||
len(fields.shape) == 3 and fields.shape[-1] == 1
|
||||
), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
|
||||
|
||||
fields = fields.reshape(1, *([grid_size] * 3))
|
||||
|
||||
# create grid 128 x 128 x 128
|
||||
# - force a negative border around the SDFs to close off all the models.
|
||||
full_grid = torch.zeros(
|
||||
1,
|
||||
grid_size + 2,
|
||||
grid_size + 2,
|
||||
grid_size + 2,
|
||||
device=fields.device,
|
||||
dtype=fields.dtype,
|
||||
)
|
||||
full_grid.fill_(-1.0)
|
||||
full_grid[:, 1:-1, 1:-1, 1:-1] = fields
|
||||
fields = full_grid
|
||||
|
||||
# apply a differentiable implementation of Marching Cubes to construct meshs
|
||||
raw_meshes = []
|
||||
mesh_mask = []
|
||||
|
||||
for field in fields:
|
||||
raw_mesh = self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min)
|
||||
mesh_mask.append(True)
|
||||
raw_meshes.append(raw_mesh)
|
||||
|
||||
mesh_mask = torch.tensor(mesh_mask, device=fields.device)
|
||||
max_vertices = max(len(m.verts) for m in raw_meshes)
|
||||
|
||||
# 3.2. query the texture color head at each vertex of the resulting mesh.
|
||||
texture_query_positions = torch.stack(
|
||||
[m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],
|
||||
dim=0,
|
||||
)
|
||||
texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype)
|
||||
|
||||
textures = []
|
||||
|
||||
for idx in range(0, texture_query_positions.shape[1], query_batch_size):
|
||||
query_batch = texture_query_positions[:, idx : idx + query_batch_size]
|
||||
|
||||
texture_model_out = self.mlp(
|
||||
position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
|
||||
)
|
||||
textures.append(texture_model_out.channels)
|
||||
|
||||
# predict texture color
|
||||
textures = torch.cat(textures, dim=1)
|
||||
|
||||
textures = _convert_srgb_to_linear(textures)
|
||||
textures = textures.float()
|
||||
|
||||
# 3.3 augument the mesh with texture data
|
||||
assert len(textures.shape) == 3 and textures.shape[-1] == len(
|
||||
texture_channels
|
||||
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
|
||||
|
||||
for m, texture in zip(raw_meshes, textures):
|
||||
texture = texture[: len(m.verts)]
|
||||
m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1)))
|
||||
|
||||
return raw_meshes[0]
|
||||
|
||||
@@ -621,8 +621,8 @@ def convert_ldm_unet_checkpoint(
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
@@ -1064,7 +1064,7 @@ def convert_controlnet_checkpoint(
|
||||
if cross_attention_dim is not None:
|
||||
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
|
||||
|
||||
controlnet_model = ControlNetModel(**ctrlnet_config)
|
||||
controlnet = ControlNetModel(**ctrlnet_config)
|
||||
|
||||
# Some controlnet ckpt files are distributed independently from the rest of the
|
||||
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
|
||||
@@ -1082,9 +1082,9 @@ def convert_controlnet_checkpoint(
|
||||
skip_extract_state_dict=skip_extract_state_dict,
|
||||
)
|
||||
|
||||
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
|
||||
controlnet.load_state_dict(converted_ctrl_checkpoint)
|
||||
|
||||
return controlnet_model
|
||||
return controlnet
|
||||
|
||||
|
||||
def download_from_original_stable_diffusion_ckpt(
|
||||
@@ -1176,13 +1176,12 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
|
||||
if pipeline_class is None:
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
|
||||
|
||||
if prediction_type == "v-prediction":
|
||||
prediction_type = "v_prediction"
|
||||
@@ -1289,8 +1288,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
if controlnet is None:
|
||||
controlnet = "control_stage_config" in original_config.model.params
|
||||
|
||||
if controlnet:
|
||||
controlnet_model = convert_controlnet_checkpoint(
|
||||
controlnet = convert_controlnet_checkpoint(
|
||||
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
||||
)
|
||||
|
||||
@@ -1401,13 +1399,13 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
|
||||
if stable_unclip is None:
|
||||
if controlnet:
|
||||
pipe = StableDiffusionControlNetPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
controlnet=controlnet_model,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
@@ -1504,12 +1502,12 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
feature_extractor = None
|
||||
|
||||
if controlnet:
|
||||
pipe = StableDiffusionControlNetPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
controlnet=controlnet_model,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
@@ -1536,7 +1534,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
@@ -1624,7 +1622,7 @@ def download_controlnet_from_original_ckpt(
|
||||
if "control_stage_config" not in original_config.model.params:
|
||||
raise ValueError("`control_stage_config` not present in original config")
|
||||
|
||||
controlnet_model = convert_controlnet_checkpoint(
|
||||
controlnet = convert_controlnet_checkpoint(
|
||||
checkpoint,
|
||||
original_config,
|
||||
checkpoint_path,
|
||||
@@ -1635,4 +1633,4 @@ def download_controlnet_from_original_ckpt(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
return controlnet_model
|
||||
return controlnet
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -25,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
@@ -180,7 +179,7 @@ class StableDiffusionInpaintPipeline(
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
@@ -203,7 +202,7 @@ class StableDiffusionInpaintPipeline(
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
vae: Union[AutoencoderKL, AsymmetricAutoencoderKL],
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
@@ -513,20 +512,6 @@ class StableDiffusionInpaintPipeline(
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
warnings.warn(
|
||||
"The decode_latents method is deprecated and will be removed in a future version. Please"
|
||||
" use VaeImageProcessor instead",
|
||||
FutureWarning,
|
||||
)
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
@@ -897,6 +882,7 @@ class StableDiffusionInpaintPipeline(
|
||||
mask, masked_image, init_image = prepare_mask_and_masked_image(
|
||||
image, mask_image, height, width, return_image=True
|
||||
)
|
||||
mask_condition = mask.clone()
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
@@ -1007,7 +993,14 @@ class StableDiffusionInpaintPipeline(
|
||||
callback(i, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
condition_kwargs = {}
|
||||
if isinstance(self.vae, AsymmetricAutoencoderKL):
|
||||
init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
|
||||
init_image_condition = init_image.clone()
|
||||
init_image = self._encode_vae_image(init_image, generator=generator)
|
||||
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
|
||||
condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -79,28 +79,28 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# prior: PriorTransformer,
|
||||
prior: PriorTransformer,
|
||||
decoder: UNet2DConditionModel,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_proj: UnCLIPTextProjModel,
|
||||
super_res_first: UNet2DModel,
|
||||
super_res_last: UNet2DModel,
|
||||
# prior_scheduler: UnCLIPScheduler,
|
||||
prior_scheduler: UnCLIPScheduler,
|
||||
decoder_scheduler: UnCLIPScheduler,
|
||||
super_res_scheduler: UnCLIPScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
# prior=prior,
|
||||
prior=prior,
|
||||
decoder=decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_proj=text_proj,
|
||||
super_res_first=super_res_first,
|
||||
super_res_last=super_res_last,
|
||||
# prior_scheduler=prior_scheduler,
|
||||
prior_scheduler=prior_scheduler,
|
||||
decoder_scheduler=decoder_scheduler,
|
||||
super_res_scheduler=super_res_scheduler,
|
||||
)
|
||||
|
||||
@@ -103,7 +103,7 @@ if is_torch_available():
|
||||
)
|
||||
from .torch_utils import maybe_allow_in_graph
|
||||
|
||||
from .testing_utils import export_to_gif, export_to_video
|
||||
from .testing_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -2,6 +2,21 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class AsymmetricAutoencoderKL(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKL(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2,6 +2,21 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "invisible_watermark"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "invisible_watermark"]
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import inspect
|
||||
import io
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import struct
|
||||
import tempfile
|
||||
import unittest
|
||||
import urllib.parse
|
||||
from contextlib import contextmanager
|
||||
from distutils.util import strtobool
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
@@ -315,6 +318,85 @@ def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) ->
|
||||
return output_gif_path
|
||||
|
||||
|
||||
@contextmanager
|
||||
def buffered_writer(raw_f):
|
||||
f = io.BufferedWriter(raw_f)
|
||||
yield f
|
||||
f.flush()
|
||||
|
||||
|
||||
def export_to_ply(mesh, output_ply_path: str = None):
|
||||
"""
|
||||
Write a PLY file for a mesh.
|
||||
"""
|
||||
if output_ply_path is None:
|
||||
output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name
|
||||
|
||||
coords = mesh.verts.detach().cpu().numpy()
|
||||
faces = mesh.faces.cpu().numpy()
|
||||
rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
|
||||
|
||||
with buffered_writer(open(output_ply_path, "wb")) as f:
|
||||
f.write(b"ply\n")
|
||||
f.write(b"format binary_little_endian 1.0\n")
|
||||
f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
|
||||
f.write(b"property float x\n")
|
||||
f.write(b"property float y\n")
|
||||
f.write(b"property float z\n")
|
||||
if rgb is not None:
|
||||
f.write(b"property uchar red\n")
|
||||
f.write(b"property uchar green\n")
|
||||
f.write(b"property uchar blue\n")
|
||||
if faces is not None:
|
||||
f.write(bytes(f"element face {len(faces)}\n", "ascii"))
|
||||
f.write(b"property list uchar int vertex_index\n")
|
||||
f.write(b"end_header\n")
|
||||
|
||||
if rgb is not None:
|
||||
rgb = (rgb * 255.499).round().astype(int)
|
||||
vertices = [
|
||||
(*coord, *rgb)
|
||||
for coord, rgb in zip(
|
||||
coords.tolist(),
|
||||
rgb.tolist(),
|
||||
)
|
||||
]
|
||||
format = struct.Struct("<3f3B")
|
||||
for item in vertices:
|
||||
f.write(format.pack(*item))
|
||||
else:
|
||||
format = struct.Struct("<3f")
|
||||
for vertex in coords.tolist():
|
||||
f.write(format.pack(*vertex))
|
||||
|
||||
if faces is not None:
|
||||
format = struct.Struct("<B3I")
|
||||
for tri in faces.tolist():
|
||||
f.write(format.pack(len(tri), *tri))
|
||||
|
||||
return output_ply_path
|
||||
|
||||
|
||||
def export_to_obj(mesh, output_obj_path: str = None):
|
||||
if output_obj_path is None:
|
||||
output_obj_path = tempfile.NamedTemporaryFile(suffix=".obj").name
|
||||
|
||||
verts = mesh.verts.detach().cpu().numpy()
|
||||
faces = mesh.faces.cpu().numpy()
|
||||
|
||||
vertex_colors = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
|
||||
vertices = [
|
||||
"{} {} {} {} {} {}".format(*coord, *color) for coord, color in zip(verts.tolist(), vertex_colors.tolist())
|
||||
]
|
||||
|
||||
faces = ["f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1)) for tri in faces.tolist()]
|
||||
|
||||
combined_data = ["v " + vertex for vertex in vertices] + faces
|
||||
|
||||
with open(output_obj_path, "w") as f:
|
||||
f.writelines("\n".join(combined_data))
|
||||
|
||||
|
||||
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
|
||||
if is_opencv_available():
|
||||
import cv2
|
||||
|
||||
@@ -52,27 +52,21 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output(self):
|
||||
super().test_output()
|
||||
|
||||
@@ -89,12 +83,11 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
"mid_block_type": "MidResTemporalBlock1D",
|
||||
"down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
|
||||
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
|
||||
"act_fn": "mish",
|
||||
"act_fn": "swish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
|
||||
@@ -107,7 +100,6 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output_pretrained(self):
|
||||
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
|
||||
torch.manual_seed(0)
|
||||
@@ -177,27 +169,21 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
def output_shape(self):
|
||||
return (4, 14, 1)
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output(self):
|
||||
# UNetRL is a value-function is different output shape
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
@@ -241,7 +227,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_hub(self):
|
||||
value_function, vf_loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
@@ -254,7 +239,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output_pretrained(self):
|
||||
value_function, vf_loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers import AsymmetricAutoencoderKL, AutoencoderKL
|
||||
from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
@@ -173,6 +173,56 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AsymmetricAutoencoderKL
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
mask = torch.ones((batch_size, 1) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image, "mask": mask}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
"down_block_out_channels": [32, 64],
|
||||
"layers_per_down_block": 1,
|
||||
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
"up_block_out_channels": [32, 64],
|
||||
"layers_per_up_block": 1,
|
||||
"act_fn": "silu",
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 32,
|
||||
"scaling_factor": 0.18215,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_forward_signature(self):
|
||||
pass
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
@@ -199,7 +249,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
torch_dtype=torch_dtype,
|
||||
revision=revision,
|
||||
)
|
||||
model.to(torch_device).eval()
|
||||
model.to(torch_device)
|
||||
|
||||
return model
|
||||
|
||||
@@ -383,3 +433,168 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
tolerance = 3e-3 if torch_device != "mps" else 1e-2
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
|
||||
|
||||
def test_stable_diffusion_model_local(self):
|
||||
model_id = "stabilityai/sd-vae-ft-mse"
|
||||
model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
|
||||
model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
|
||||
image = self.get_sd_image(33)
|
||||
|
||||
with torch.no_grad():
|
||||
sample_1 = model_1(image).sample
|
||||
sample_2 = model_2(image).sample
|
||||
|
||||
assert sample_1.shape == sample_2.shape
|
||||
|
||||
output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
|
||||
assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)
|
||||
|
||||
|
||||
@slow
|
||||
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return image
|
||||
|
||||
def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
|
||||
revision = "main"
|
||||
torch_dtype = torch.float32
|
||||
|
||||
model = AsymmetricAutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch_dtype,
|
||||
revision=revision,
|
||||
)
|
||||
model.to(torch_device).eval()
|
||||
|
||||
return model
|
||||
|
||||
def get_generator(self, seed=0):
|
||||
if torch_device == "mps":
|
||||
return torch.manual_seed(seed)
|
||||
return torch.Generator(device=torch_device).manual_seed(seed)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[33, [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078], [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]],
|
||||
[47, [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529], [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(image, generator=generator, sample_posterior=True).sample
|
||||
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[33, [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097], [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078]],
|
||||
[47, [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531], [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(image).sample
|
||||
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
|
||||
[37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
def test_stable_diffusion_decode(self, seed, expected_slice):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model.decode(encoding).sample
|
||||
|
||||
assert list(sample.shape) == [3, 3, 512, 512]
|
||||
|
||||
output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=2e-3)
|
||||
|
||||
@parameterized.expand([(13,), (16,), (37,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model.decode(encoding).sample
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
with torch.no_grad():
|
||||
sample_2 = model.decode(encoding).sample
|
||||
|
||||
assert list(sample.shape) == [3, 3, 512, 512]
|
||||
|
||||
assert torch_all_close(sample, sample_2, atol=5e-2)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
# fmt: off
|
||||
[33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
|
||||
[47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion_encode_sample(self, seed, expected_slice):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
dist = model.encode(image).latent_dist
|
||||
sample = dist.sample(generator=generator)
|
||||
|
||||
assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
|
||||
|
||||
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice)
|
||||
|
||||
tolerance = 3e-3 if torch_device != "mps" else 1e-2
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
|
||||
|
||||
@@ -398,6 +398,179 @@ class StableDiffusionMultiControlNetPipelineFastTests(
|
||||
pass
|
||||
|
||||
|
||||
class StableDiffusionMultiControlNetOneModelPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
torch.nn.init.normal(m.weight)
|
||||
m.bias.data.fill_(1.0)
|
||||
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
controlnet.controlnet_down_blocks.apply(init_weights)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
controlnet = MultiControlNetModel([controlnet])
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
|
||||
images = [
|
||||
randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
),
|
||||
]
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"image": images,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_control_guidance_switch(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
|
||||
scale = 10.0
|
||||
steps = 4
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = steps
|
||||
inputs["controlnet_conditioning_scale"] = scale
|
||||
output_1 = pipe(**inputs)[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = steps
|
||||
inputs["controlnet_conditioning_scale"] = scale
|
||||
output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = steps
|
||||
inputs["controlnet_conditioning_scale"] = scale
|
||||
output_3 = pipe(
|
||||
**inputs,
|
||||
control_guidance_start=[0.1],
|
||||
control_guidance_end=[0.2],
|
||||
)[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = steps
|
||||
inputs["controlnet_conditioning_scale"] = scale
|
||||
output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5])[0]
|
||||
|
||||
# make sure that all outputs are different
|
||||
assert np.sum(np.abs(output_1 - output_2)) > 1e-3
|
||||
assert np.sum(np.abs(output_1 - output_3)) > 1e-3
|
||||
assert np.sum(np.abs(output_1 - output_4)) > 1e-3
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_save_pretrained_raise_not_implemented_exception(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
try:
|
||||
# save_pretrained is not implemented for Multi-ControlNet
|
||||
pipe.save_pretrained(tmpdir)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
@@ -752,6 +925,42 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
)
|
||||
|
||||
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
|
||||
images.append(output.images[0])
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -401,3 +401,49 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
assert np.abs(expected_image - image).max() < 9e-2
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
)
|
||||
images.append(output.images[0])
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
|
||||
@@ -543,3 +543,54 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
assert np.abs(expected_image - image).max() < 9e-2
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetInpaintPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
).resize((512, 512))
|
||||
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
mask_image=mask_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
)
|
||||
images.append(output.images[0])
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
EulerDiscreteScheduler,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import randn_tensor, torch_device
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
)
|
||||
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class ControlNetPipelineSDXLFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionXLControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
# SD2-specific config below
|
||||
attention_head_dim=(2, 4),
|
||||
use_linear_projection=True,
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
|
||||
cross_attention_dim=64,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
# SD2-specific config below
|
||||
attention_head_dim=(2, 4),
|
||||
use_linear_projection=True,
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
|
||||
cross_attention_dim=64,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
steps_offset=1,
|
||||
beta_schedule="scaled_linear",
|
||||
timestep_spacing="leading",
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
# SD2-specific config below
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
|
||||
|
||||
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"image": image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
@@ -131,7 +131,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
prior = self.dummy_prior
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = self.dummy_tokenizer
|
||||
renderer = self.dummy_renderer
|
||||
shap_e_renderer = self.dummy_renderer
|
||||
|
||||
scheduler = HeunDiscreteScheduler(
|
||||
beta_schedule="exp",
|
||||
@@ -145,7 +145,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"prior": prior,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"renderer": renderer,
|
||||
"shap_e_renderer": shap_e_renderer,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
prior = self.dummy_prior
|
||||
image_encoder = self.dummy_image_encoder
|
||||
image_processor = self.dummy_image_processor
|
||||
renderer = self.dummy_renderer
|
||||
shap_e_renderer = self.dummy_renderer
|
||||
|
||||
scheduler = HeunDiscreteScheduler(
|
||||
beta_schedule="exp",
|
||||
@@ -157,7 +157,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"prior": prior,
|
||||
"image_encoder": image_encoder,
|
||||
"image_processor": image_processor,
|
||||
"renderer": renderer,
|
||||
"shap_e_renderer": shap_e_renderer,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
@@ -552,6 +553,230 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||
assert np.max(np.abs(image - image_ckpt)) < 1e-4
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintPipelineAsymmetricAutoencoderKLSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
)
|
||||
inputs = {
|
||||
"prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
|
||||
"image": init_image,
|
||||
"mask_image": mask_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_inpaint_ddim(self):
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None
|
||||
)
|
||||
pipe.vae = vae
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0521, 0.0606, 0.0602, 0.0446, 0.0495, 0.0434, 0.1175, 0.1290, 0.1431])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 6e-4
|
||||
|
||||
def test_stable_diffusion_inpaint_fp16(self):
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained(
|
||||
"cross-attention/asymmetric-autoencoder-kl-x-1-5", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, safety_checker=None
|
||||
)
|
||||
pipe.vae = vae
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs(torch_device, dtype=torch.float16)
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.1343, 0.1406, 0.1440, 0.1504, 0.1729, 0.0989, 0.1807, 0.2822, 0.1179])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 5e-2
|
||||
|
||||
def test_stable_diffusion_inpaint_pndm(self):
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None
|
||||
)
|
||||
pipe.vae = vae
|
||||
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0976, 0.1071, 0.1119, 0.1363, 0.1260, 0.1150, 0.3745, 0.3586, 0.3340])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 5e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_k_lms(self):
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None
|
||||
)
|
||||
pipe.vae = vae
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.8909, 0.8620, 0.9024, 0.8501, 0.8558, 0.9074, 0.8790, 0.7540, 0.9003])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 6e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained(
|
||||
"cross-attention/asymmetric-autoencoder-kl-x-1-5", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.vae = vae
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing(1)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs(torch_device, dtype=torch.float16)
|
||||
_ = pipe(**inputs)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 2.45 GB is allocated
|
||||
assert mem_bytes < 2.45 * 10**9
|
||||
|
||||
@require_torch_2
|
||||
def test_inpaint_compile(self):
|
||||
pass
|
||||
|
||||
def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained(
|
||||
"cross-attention/asymmetric-autoencoder-kl-x-1-5",
|
||||
)
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None
|
||||
)
|
||||
pipe.vae = vae
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
# change input image to a random size (one that would cause a tensor mismatch error)
|
||||
inputs["image"] = inputs["image"].resize((127, 127))
|
||||
inputs["mask_image"] = inputs["mask_image"].resize((127, 127))
|
||||
inputs["height"] = 128
|
||||
inputs["width"] = 128
|
||||
image = pipe(**inputs).images
|
||||
# verify that the returned image has the same height and width as the input height and width
|
||||
assert image.shape == (1, inputs["height"], inputs["width"], 3)
|
||||
|
||||
def test_stable_diffusion_inpaint_strength_test(self):
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None
|
||||
)
|
||||
pipe.vae = vae
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
# change input strength
|
||||
inputs["strength"] = 0.75
|
||||
image = pipe(**inputs).images
|
||||
# verify that the returned image has the same height and width as the input height and width
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
expected_slice = np.array([0.2458, 0.2576, 0.3124, 0.2679, 0.2669, 0.2796, 0.2872, 0.2975, 0.2661])
|
||||
assert np.abs(expected_slice - image_slice).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_simple_inpaint_ddim(self):
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
|
||||
pipe.vae = vae
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.3312, 0.4052, 0.4103, 0.4153, 0.4347, 0.4154, 0.4932, 0.4920, 0.4431])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 6e-4
|
||||
|
||||
def test_download_local(self):
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained(
|
||||
"cross-attention/asymmetric-autoencoder-kl-x-1-5", torch_dtype=torch.float16
|
||||
)
|
||||
filename = hf_hub_download("runwayml/stable-diffusion-inpainting", filename="sd-v1-5-inpainting.ckpt")
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
|
||||
pipe.vae = vae
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to("cuda")
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 1
|
||||
image_out = pipe(**inputs).images[0]
|
||||
|
||||
assert image_out.shape == (512, 512, 3)
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
pass
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user