Compare commits

..

8 Commits

Author SHA1 Message Date
Dhruv Nair 16fe810b3e test 2023-11-07 05:35:25 +00:00
Dhruv Nair 79e7fb3a16 test 2023-11-07 05:33:25 +00:00
Dhruv Nair c15d9936c1 test 2023-11-07 05:27:34 +00:00
Beinsezii 9bafef34bd Add Pixart to AUTO_TEXT2IMAGE_PIPELINES_MAPPING (#5664) 2023-11-07 07:45:56 +05:30
Sayak Paul 64603389da post release (v0.22.0) (#5658)
post release
2023-11-06 16:23:38 +01:00
Patrick von Platen f05d75c076 [Custom Pipelines] Make sure that community pipelines can use repo revision (#5659)
fix custom pipelines
2023-11-06 15:11:48 +01:00
Sayak Paul aec3de8bdb correct pipeline class name (#5652) 2023-11-06 14:08:27 +05:30
Sayak Paul d61889fc17 [Feat] PixArt-Alpha (#5642)
* init pixart alpha pipeline

* fix: import

* script

* script

* script

* add: vae to the pipeline

* add: vae_scale_factor

* add: checkpoint_path

* clean conversion script a bit.

* size embeddings.

* fix: size embedding

* update scrip

* support for interpolation of position embedding.

* support for conditioning.

* ..

* ..

* ..

* final layer

* final layer

* align if encode_prompt

* support for caption embedding

* refactor

* refactor

* refactor

* start cross attention

* start cross attention

* cross_attention_dim

* cross

* cross

* support for resolution and aspect_ratio

* support for caption projection

* refactor patch embeddings

* batch_size

* up

* commit

* commit

* commit.

* squeeze

* squeeze

* squeeze

* squeeze

* squeeze

* squeeze

* squeeze

* squeeze

* squeeze

* squeeze

* squeeze

* squeeze.

* squeeze.

* fix final block./

* fix final block./

* fix final block./

* clean

* fix: interpolation scale.

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging'

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* make --checkpoint_path non-required.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* remove num_tokens

* timesteps -> timestep

* timesteps -> timestep

* timesteps -> timestep

* timesteps -> timestep

* timesteps -> timestep

* timesteps -> timestep

* debug

* debug

* update conversion script.

* update conversion script.

* update conversion script.

* debug

* debug

* debug

* clean

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* deug

* debug

* debug

* debug

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* clean

* fix

* fix

* boom

* boom

* some changes

* boom

* save

* up

* remove i

* fix more tests

* DPMSolverMultistepScheduler

* fix

* offloading

* fix conversion script

* fix conversion script

* remove print

* remove support for negative prompt embeds.

* typo.

* remove extra kwargs

* bring conversion script to where it was

* fix

* trying mu luck

* trying my luck again

* again

* again

* again

* clean up

* up

* up

* update example

* support for 512

* remove spacing

* finalize docs.

* test debug

* fix: assertion values.

* debug

* debug

* debug

* fix: repeat

* remove prints.

* Apply suggestions from code review

* Apply suggestions from code review

* Correct more

* Apply suggestions from code review

* Change all

* Clean more

* fix more

* Fix more

* Fix more

* Correct more

* address patrick's comments.

* remove unneeded args

* clean up pipeline.

* sty;e

* make the use of additional conditions better conditioned.

* None better

* dtype

* height and width validation

* add a note about size brackets.

* fix

* spit out slow test outputs.

* fix?

* fix optional test

* fix more

* remove unneeded comment

* debug

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-11-06 08:40:04 +01:00
52 changed files with 1540 additions and 894 deletions
+2 -1
View File
@@ -28,5 +28,6 @@ jobs:
pip install pytest
- name: Check for soft dependencies
run: |
python utils/print_env.py
python -c "from diffusers import UNet2DConditionModel; print(type(UNet2DConditionModel))"
pytest tests/others/test_dependencies.py
+2 -4
View File
@@ -19,8 +19,6 @@
title: Train a diffusion model
- local: tutorials/using_peft_for_inference
title: Inference with PEFT
- local: tutorials/workflows
title: Working with workflows
title: Tutorials
- sections:
- sections:
@@ -180,8 +178,6 @@
title: Logging
- local: api/outputs
title: Outputs
- local: api/workflows
title: Shareable workflows
title: Main Classes
- sections:
- local: api/models/overview
@@ -272,6 +268,8 @@
title: Parallel Sampling of Diffusion Models
- local: api/pipelines/pix2pix_zero
title: Pix2Pix Zero
- local: api/pipelines/pixart
title: PixArt
- local: api/pipelines/pndm
title: PNDM
- local: api/pipelines/repaint
+36
View File
@@ -0,0 +1,36 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# PixArt
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/header_collage.png)
[PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis](https://huggingface.co/papers/2310.00426) is Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li.
The abstract from the paper is:
*The most advanced text-to-image (T2I) models require significant training costs (e.g., millions of GPU hours), seriously hindering the fundamental innovation for the AIGC community while increasing CO2 emissions. This paper introduces PIXART-α, a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators (e.g., Imagen, SDXL, and even Midjourney), reaching near-commercial application standards. Additionally, it supports high-resolution image synthesis up to 1024px resolution with low training cost, as shown in Figure 1 and 2. To achieve this goal, three core designs are proposed: (1) Training strategy decomposition: We devise three distinct training steps that separately optimize pixel dependency, text-image alignment, and image aesthetic quality; (2) Efficient T2I Transformer: We incorporate cross-attention modules into Diffusion Transformer (DiT) to inject text conditions and streamline the computation-intensive class-condition branch; (3) High-informative data: We emphasize the significance of concept density in text-image pairs and leverage a large Vision-Language model to auto-label dense pseudo-captions to assist text-image alignment learning. As a result, PIXART-α's training speed markedly surpasses existing large-scale T2I models, e.g., PIXART-α only takes 10.8% of Stable Diffusion v1.5's training time (675 vs. 6,250 A100 GPU days), saving nearly $300,000 ($26,000 vs. $320,000) and reducing 90% CO2 emissions. Moreover, compared with a larger SOTA model, RAPHAEL, our training cost is merely 1%. Extensive experiments demonstrate that PIXART-α excels in image quality, artistry, and semantic control. We hope PIXART-α will provide new insights to the AIGC community and startups to accelerate building their own high-quality yet low-cost generative models from scratch.*
You can find the original codebase at [PixArt-alpha/PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha) and all the available checkpoints at [PixArt-alpha](https://huggingface.co/PixArt-alpha).
Some notes about this pipeline:
* It uses a Transformer backbone (instead of a UNet) for denoising. As such it has a similar architecture as [DiT](./dit.md).
* It was trained using text conditions computed from T5. This aspect makes the pipeline better at following complex text prompts with intricate details.
* It is good at producing high-resolution images at different aspect ratios. To get the best results, the authors recommend some size brackets which can be found [here](https://github.com/PixArt-alpha/PixArt-alpha/blob/08fbbd281ec96866109bdd2cdb75f2f58fb17610/diffusion/data/datasets/utils.py).
* It rivals the quality of state-of-the-art text-to-image generation systems (as of this writing) such as Stable Diffusion XL, Imagen, and DALL-E 2, while being more efficient than them.
## PixArtAlphaPipeline
[[autodoc]] PixArtAlphaPipeline
- all
- __call__
-7
View File
@@ -1,7 +0,0 @@
# Shareable workflows
Workflows provide a simple mechanism to share your 🤗 Diffusers pipeline call arguments and scheduler configuration, making it easier to reproduce results.
## Workflow
[[autodoc]] workflow_utils.Workflow
-333
View File
@@ -1,333 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Working with workflows
<Tip warning={true}>
🧪 Workflow is experimental and its APIs can change in the future.
</Tip>
Workflows provide a simple mechanism to share your pipeline call arguments and scheduler configuration, making it easier to reproduce results.
## Serializing a workflow
A [`Workflow`] object provides all the argument values in the `__call__()` of a pipeline. Add `return_workflow=True` to return a `Workflow` object.
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
outputs = pipeline(
"A painting of a horse",
num_inference_steps=15,
generator=torch.manual_seed(0),
return_workflow=True
)
workflow = outputs.workflow
```
<Tip warning={true}>
It's mandatory to specify the `generator` when `return_workflow` is set to True.
</Tip>
If you look at this specific workflow, you'll see values like the number of inference steps, guidance scale, and height and width as well as the scheduler details:
```bash
{'prompt': 'A painting of a horse',
'height': None,
'width': None,
'num_inference_steps': 15,
'guidance_scale': 7.5,
'negative_prompt': None,
'eta': 0.0,
'latents': None,
'prompt_embeds': None,
'negative_prompt_embeds': None,
'output_type': 'pil',
'return_dict': True,
'callback': None,
'callback_steps': 1,
'cross_attention_kwargs': None,
'guidance_rescale': 0.0,
'clip_skip': None,
'generator_seed': 0,
'generator_device': device(type='cpu'),
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('prediction_type', 'epsilon'),
('timestep_spacing', 'leading'),
('steps_offset', 1),
('_use_default_values', ['prediction_type', 'timestep_spacing']),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.6.0'),
('clip_sample', False)])}
```
Once you have generated a workflow object, you can serialize it with [`~Workflow.save_workflow`]:
```python
outputs.workflow.save_workflow("my-simple-workflow-sd")
```
By default, your workflows are saved as `diffusion_workflow.json`, but you can give them a specific name with the `filename` argument:
```python
outputs.workflow.save_workflow("my-simple-workflow-sd", filename="my_workflow.json")
```
You can also set `push_to_hub=True` in [`~Workflow.save_workflow`] to directly push the workflow object to the Hub.
## Loading a workflow
You can load a workflow in a pipeline with [`~DiffusionPipeline.load_workflow`]:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
```
Once the pipeline is loaded with the desired workflow, it's ready to be called:
```python
image = pipeline().images[0]
```
By default, while loading a workflow, the scheduler of the underlying pipeline from the workflow isn't modified but you can change it by adding `load_scheduler=True`:
```
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd", load_scheduler=True)
```
This is particularly useful if you have changed the scheduler after loading a pipeline.
You can also override the pipeline call arguments. For example, to add a `negative_prompt`:
```python
image = pipeline(negative_prompt="bad quality").images[0]
```
Loading from a workflow is possible by specifying the `filename` argument inside the [`DiffusionPipeline.load_workflow`] method.
A workflow doesn't necessarily have to be used with the same pipeline that generated it. You can use it with a different pipeline too:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
image = pipeline().images[0]
```
However, make sure to thoroughly inspect the values you are calling the pipeline with, in this case.
Loading from a local workflow is also possible:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_workflow("path_to_local_dir")
image = pipeline().images[0]
```
Alternatively, if you want to load a workflow file and populate the pipeline arguments manually:
```python
from diffusers import DiffusionPipeline
import json
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
with open("path_to_workflow_file.json") as f:
workflow = json.load(f)
pipeline.load_workflow(workflow)
images = pipeline().images[0]
```
## Unsupported serialization types
Image-to-image pipelines like [`StableDiffusionControlNetPipeline`] accept one or more images in their `call` method. Currently, workflows don't support serializing `call` arguments that are of type `PIL.Image.Image` or `List[PIL.Image.Image]`. To make those pipelines work with workflows, you need to pass the images manually.
Let's say you generated the workflow below:
```python
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import numpy as np
import torch
import cv2
from PIL import Image
# download an image
image = load_image(
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image = np.array(image)
# get canny image
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
# load control net and stable diffusion v1-5
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
# generate image
generator = torch.manual_seed(0)
outputs = pipe(
prompt="futuristic-looking office",
image=canny_image,
num_inference_steps=20,
generator=generator,
return_workflow=True
)
workflow = outputs.workflow
```
If you look at the workflow, you'll see the image that was passed to the pipeline isn't included:
```bash
{'prompt': 'futuristic-looking office',
'height': None,
'width': None,
'num_inference_steps': 20,
'guidance_scale': 7.5,
'negative_prompt': None,
'eta': 0.0,
'latents': None,
'prompt_embeds': None,
'negative_prompt_embeds': None,
'output_type': 'pil',
'return_dict': True,
'callback': None,
'callback_steps': 1,
'cross_attention_kwargs': None,
'controlnet_conditioning_scale': 1.0,
'guess_mode': False,
'control_guidance_start': 0.0,
'control_guidance_end': 1.0,
'clip_skip': None,
'generator_seed': 0,
'generator_device': 'cpu',
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('solver_order', 2),
('prediction_type', 'epsilon'),
('thresholding', False),
('dynamic_thresholding_ratio', 0.995),
('sample_max_value', 1.0),
('predict_x0', True),
('solver_type', 'bh2'),
('lower_order_final', True),
('disable_corrector', []),
('solver_p', None),
('use_karras_sigmas', False),
('timestep_spacing', 'linspace'),
('steps_offset', 1),
('_use_default_values',
['lower_order_final',
'sample_max_value',
'solver_p',
'dynamic_thresholding_ratio',
'thresholding',
'solver_type',
'prediction_type',
'predict_x0',
'use_karras_sigmas',
'disable_corrector',
'timestep_spacing',
'solver_order']),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.6.0'),
('clip_sample', False)])}
```
Let's serialize the workflow and reload the pipeline to see what happens when you try to use it.
```python
workflow.save_workflow("my-simple-workflow-sd", filename="controlnet_simple.json", push_to_hub=True)
```
Then load the workflow into [`StableDiffusionControlNetPipeline`]:
```python
# load control net and stable diffusion v1-5
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.load_workflow("sayakpaul/my-simple-workflow-sd", filename="controlnet_simple.json")
```
If you try to generate an image now, it'll return the following error:
```bash
TypeError: image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is <class 'NoneType'>
```
To resolve the error, manually pass the conditioning image `canny_image`:
```python
image = pipe(image=canny_image).images[0]
```
Other unsupported serialization types include:
* LoRA checkpoints: any information from LoRA checkpoints that might be loaded into a pipeline isn't serialized. Workflows generated from pipelines loaded with a LoRA checkpoint should be handled cautiously! You should ensure the LoRA checkpoint is loaded into the pipeline first before loading the corresponding workflow.
* Call arguments including the following types: `torch.Tensor`, `np.ndarray`, `Callable`, `PIL.Image.Image`, and `List[PIL.Image.Image]`.
+1 -1
View File
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+1 -1
View File
@@ -68,7 +68,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
@@ -53,7 +53,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0")
check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -0,0 +1,198 @@
import argparse
import os
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel
ckpt_id = "PixArt-alpha/PixArt-alpha"
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
interpolation_scale = {512: 1, 1024: 2}
def main(args):
all_state_dict = torch.load(args.orig_ckpt_path)
state_dict = all_state_dict.pop("state_dict")
converted_state_dict = {}
# Patch embeddings.
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
# Caption projection.
converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# AdaLN-single LN
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
if args.image_size == 1024:
# Resolution.
converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
"csize_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
"csize_embedder.mlp.0.bias"
)
converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
"csize_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
"csize_embedder.mlp.2.bias"
)
# Aspect ratio.
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
"ar_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
"ar_embedder.mlp.0.bias"
)
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
"ar_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
"ar_embedder.mlp.2.bias"
)
# Shared norm.
converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")
for depth in range(28):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
f"blocks.{depth}.scale_shift_table"
)
# Attention is all you need 🤘
# Self attention.
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.attn.proj.bias"
)
# Feed-forward.
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.fc1.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.fc1.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.fc2.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.fc2.bias"
)
# Cross-attention.
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.bias"
)
# Final block.
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
# DiT XL/2
transformer = Transformer2DModel(
sample_size=args.image_size // 8,
num_layers=28,
attention_head_dim=72,
in_channels=4,
out_channels=8,
patch_size=2,
attention_bias=True,
num_attention_heads=16,
cross_attention_dim=1152,
activation_fn="gelu-approximate",
num_embeds_ada_norm=1000,
norm_type="ada_norm_single",
norm_elementwise_affine=False,
norm_eps=1e-6,
caption_channels=4096,
)
transformer.load_state_dict(converted_state_dict, strict=True)
assert transformer.pos_embed.pos_embed is not None
state_dict.pop("pos_embed")
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
if args.only_transformer:
transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
else:
scheduler = DPMSolverMultistepScheduler()
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="sd-vae-ft-ema")
tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")
text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")
pipeline = PixArtAlphaPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
)
pipeline.save_pretrained(args.dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--image_size",
default=1024,
type=int,
choices=[512, 1024],
required=False,
help="Image size of pretrained model, either 512 or 1024.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--only_transformer", default=True, type=bool, required=True)
args = parser.parse_args()
main(args)
+1 -1
View File
@@ -244,7 +244,7 @@ install_requires = [
setup(
name="diffusers",
version="0.22.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.23.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+3 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.22.0.dev0"
__version__ = "0.23.0.dev0"
from typing import TYPE_CHECKING
@@ -235,6 +235,7 @@ else:
"LDMTextToImagePipeline",
"MusicLDMPipeline",
"PaintByExamplePipeline",
"PixArtAlphaPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -579,6 +580,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
MusicLDMPipeline,
PaintByExamplePipeline,
PixArtAlphaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
+53 -11
View File
@@ -117,7 +117,8 @@ class BasicTransformerBlock(nn.Module):
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
@@ -128,6 +129,8 @@ class BasicTransformerBlock(nn.Module):
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
@@ -152,7 +155,8 @@ class BasicTransformerBlock(nn.Module):
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
@@ -171,7 +175,7 @@ class BasicTransformerBlock(nn.Module):
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
)
self.attn2 = Attention(
query_dim=dim,
@@ -187,13 +191,19 @@ class BasicTransformerBlock(nn.Module):
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
@@ -215,14 +225,25 @@ class BasicTransformerBlock(nn.Module):
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
@@ -242,19 +263,31 @@ class BasicTransformerBlock(nn.Module):
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 2.5 ends
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
if self.pos_embed is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
@@ -266,11 +299,16 @@ class BasicTransformerBlock(nn.Module):
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
@@ -291,8 +329,12 @@ class BasicTransformerBlock(nn.Module):
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
+113 -6
View File
@@ -66,17 +66,22 @@ def get_timestep_embedding(
return emb
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
def get_2d_sincos_pos_embed(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
@@ -129,6 +134,7 @@ class PatchEmbed(nn.Module):
layer_norm=False,
flatten=True,
bias=True,
interpolation_scale=1,
):
super().__init__()
@@ -144,16 +150,41 @@ class PatchEmbed(nn.Module):
else:
self.norm = None
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
self.patch_size = patch_size
# See:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
self.interpolation_scale = interpolation_scale
pos_embed = get_2d_sincos_pos_embed(
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
return latent + self.pos_embed
# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed)
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
else:
pos_embed = self.pos_embed
return (latent + pos_embed).to(latent.dtype)
class TimestepEmbedding(nn.Module):
@@ -683,3 +714,79 @@ class PositionNet(nn.Module):
objs = torch.cat([objs_text, objs_image], dim=1)
return objs
class CombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_additional_conditions = use_additional_conditions
if use_additional_conditions:
self.use_additional_conditions = True
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
if size.ndim == 1:
size = size[:, None]
if size.shape[0] != batch_size:
size = size.repeat(batch_size // size.shape[0], 1)
if size.shape[0] != batch_size:
raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
current_batch_size, dims = size.shape[0], size.shape[1]
size = size.reshape(-1)
size_freq = self.additional_condition_proj(size).to(size.dtype)
size_emb = embedder(size_freq)
size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
return size_emb
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
if self.use_additional_conditions:
resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
aspect_ratio = self.apply_condition(
aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
)
conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
else:
conditioning = timesteps_emb
return conditioning
class CaptionProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, num_tokens=120):
super().__init__()
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
def forward(self, caption, force_drop_ids=None):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
+35 -2
View File
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings
from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
class AdaLayerNorm(nn.Module):
@@ -77,6 +77,39 @@ class AdaLayerNormZero(nn.Module):
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
super().__init__()
self.emb = CombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
batch_size: int = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class AdaGroupNorm(nn.Module):
r"""
GroupNorm layer modified to incorporate timestep embeddings.
+58 -10
View File
@@ -22,9 +22,10 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .embeddings import CaptionProjection, PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
from .normalization import AdaLayerNormSingle
@dataclass
@@ -92,7 +93,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
attention_type: str = "default",
caption_channels: int = None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
@@ -164,12 +167,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.width = sample_size
self.patch_size = patch_size
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
interpolation_scale = max(interpolation_scale, 1)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
# 3. Define transformers blocks
@@ -189,6 +195,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
for d in range(num_layers)
@@ -206,10 +213,27 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches:
elif self.is_input_patches and norm_type != "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
elif self.is_input_patches and norm_type == "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
# 5. PixArt-Alpha blocks.
self.adaln_single = None
self.use_additional_conditions = False
if norm_type == "ada_norm_single":
self.use_additional_conditions = self.config.sample_size == 128
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
# additional conditions until we find better name
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
self.caption_projection = None
if caption_channels is not None:
self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
self.gradient_checkpointing = False
@@ -218,6 +242,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -316,7 +341,22 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
elif self.is_input_patches:
hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
batch_size = hidden_states.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
@@ -367,14 +407,22 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
elif self.is_input_patches:
# TODO: cleanup!
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
if self.is_input_patches:
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
height = width = int(hidden_states.shape[1] ** 0.5)
+2
View File
@@ -117,6 +117,7 @@ else:
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_diffusion"].extend(
@@ -341,6 +342,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
from .pixart_alpha import PixArtAlphaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_diffusion import (
+2
View File
@@ -43,6 +43,7 @@ from .kandinsky2_2 import (
KandinskyV22Pipeline,
)
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pixart_alpha import PixArtAlphaPipeline
from .stable_diffusion import (
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
@@ -67,6 +68,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("wuerstchen", WuerstchenCombinedPipeline),
("lcm", LatentConsistencyModelPipeline),
("pixart", PixArtAlphaPipeline),
]
)
@@ -752,7 +752,6 @@ class StableDiffusionControlNetPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
return_workflow: bool = False,
clip_skip: Optional[int] = None,
):
r"""
@@ -825,8 +824,6 @@ class StableDiffusionControlNetPipeline(
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
return_workflow (`bool`, *optional*, defaults to `False`):
Whether to return used pipeline call arguments.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -840,14 +837,6 @@ class StableDiffusionControlNetPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# We do this first to capture the "True" call values. If we do this at a later point in time,
# we cannot ensure that the call values weren't changed during the process.
workflow = None
if return_workflow:
if generator is None:
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
workflow = self.populate_workflow_from_pipeline()
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1086,11 +1075,6 @@ class StableDiffusionControlNetPipeline(
self.maybe_free_model_hooks()
if not return_dict:
outputs = (image, has_nsfw_concept)
return (image, has_nsfw_concept)
if return_workflow:
outputs += (workflow,)
return outputs
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -166,7 +166,6 @@ class DiTPipeline(DiffusionPipeline):
# set step values
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
if guidance_scale > 1:
half = latent_model_input[: len(latent_model_input) // 2]
+6 -119
View File
@@ -22,7 +22,6 @@ import re
import sys
import warnings
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
@@ -55,9 +54,7 @@ from ..utils import (
logging,
numpy_to_pil,
)
from ..utils.constants import WORKFLOW_NAME
from ..utils.torch_utils import is_compiled_module
from ..workflow_utils import _NON_CALL_ARGUMENTS, Workflow
if is_transformers_available():
@@ -67,7 +64,6 @@ if is_transformers_available():
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
@@ -357,13 +353,18 @@ def _get_pipeline_class(
else:
file_name = CUSTOM_PIPELINE_FILE_NAME
if repo_id is not None and hub_revision is not None:
# if we load the pipeline code from the Hub
# make sure to overwrite the `revison`
revision = hub_revision
return get_class_from_dynamic_module(
custom_pipeline,
module_file=file_name,
class_name=class_name,
repo_id=repo_id,
cache_dir=cache_dir,
revision=revision if hub_revision is None else hub_revision,
revision=revision,
)
if class_obj != DiffusionPipeline:
@@ -2079,117 +2080,3 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
for module in modules:
module.set_attention_slice(slice_size)
def populate_workflow_from_pipeline(self) -> Dict:
r"""Populates the call arguments in a dictionary.
Returns:
[`Workflow`]: A dictionary containing the details of the pipeline call arguments and (optionally) LoRA
checkpoint details.
"""
# A `Workflow` object is an extended Python dictionary. So, all regular dictionary methods
# apply to it.
workflow = Workflow()
signature = inspect.signature(self.__call__)
argument_names = [param.name for param in signature.parameters.values()]
call_arg_values = inspect.getargvalues(inspect.currentframe().f_back).locals
# Populate call arguments.
call_arguments = {
arg: call_arg_values[arg]
for arg in argument_names
if arg != "return_workflow"
and "image" not in arg
and not isinstance(call_arg_values[arg], (torch.Tensor, np.ndarray, Callable))
}
workflow.update(call_arguments)
# Handle generator device and seed.
generator = workflow["generator"]
if isinstance(generator, list):
for g in generator:
if "generator_seed" not in workflow:
workflow.update({"generator_seed": [g.initial_seed()]})
workflow.update({"generator_device": [str(g.device)]})
workflow.update({"generator_state": g.get_state().numpy().tolist()})
else:
workflow["generator_seed"].append(g.initial_seed())
workflow["generator_device"].append(g.device)
workflow["generator_state"].append(g.get_state().numpy().tolist())
else:
workflow.update({"generator_seed": generator.initial_seed()})
workflow.update({"generator_device": str(generator.device)})
workflow.update({"generator_state": generator.get_state().numpy().tolist()})
workflow.pop("generator")
# Handle pipeline-level things.
if hasattr(self, "config") and hasattr(self.config, "_name_or_path"):
pipeline_config_name_or_path = self.config._name_or_path
else:
pipeline_config_name_or_path = None
workflow["_name_or_path"] = pipeline_config_name_or_path
workflow["scheduler_config"] = self.scheduler.config
return workflow
def load_workflow(
self,
workflow_id_or_path: Union[str, dict],
filename: Optional[str] = None,
):
r"""Loads a workflow from the Hub or from a local path. Also patches the pipeline call arguments with values from the
workflow.
Args:
workflow_id_or_path (`str` or `dict`):
Can be either:
- A string, the workflow id (for example `sayakpaul/sdxl-workflow`) of a workflow hosted on the
Hub.
- A path to a directory (for example `./my_workflow_directory`) containing the workflow file with
[`Workflow.save_workflow`] or [`Workflow.push_to_hub`].
- A Python dictionary.
filename (`str`, *optional*):
Optional name of the workflow file to load. Especially useful when working with multiple workflow
files.
"""
filename = filename or WORKFLOW_NAME
# Load workflow.
if not isinstance(workflow_id_or_path, dict):
if os.path.isdir(workflow_id_or_path):
workflow_filepath = os.path.join(workflow_id_or_path, filename)
elif os.path.isfile(workflow_id_or_path):
workflow_filepath = workflow_id_or_path
else:
workflow_filepath = hf_hub_download(repo_id=workflow_id_or_path, filename=filename)
workflow = self._dict_from_json_file(workflow_filepath)
else:
workflow = workflow_id_or_path
# We make a copy of the original workflow and operate on it.
workflow_copy = dict(workflow.items())
# Handle generator.
seed = workflow_copy.pop("generator_seed")
device = workflow_copy.pop("generator_device", "cpu")
last_known_state = workflow_copy.pop("generator_state")
if isinstance(seed, list):
generator = [
torch.Generator(device=d).manual_seed(s).set_state(torch.from_numpy(np.array(lst)).byte())
for s, d, lst in zip(seed, device, last_known_state)
]
else:
last_known_state = torch.from_numpy(np.array(last_known_state)).byte()
generator = torch.Generator(device=device).manual_seed(seed).set_state(last_known_state)
workflow_copy.update({"generator": generator})
# Handle non-call arguments.
final_call_args = {k: v for k, v in workflow_copy.items() if k not in _NON_CALL_ARGUMENTS}
# Handle the call here.
partial_call = partial(self.__call__, **final_call_args)
setattr(self.__class__, "__call__", partial_call)
@@ -0,0 +1 @@
from .pipeline_pixart_alpha import PixArtAlphaPipeline
@@ -0,0 +1,722 @@
# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import inspect
import re
import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, Transformer2DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
is_bs4_available,
is_ftfy_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_bs4_available():
from bs4 import BeautifulSoup
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import PixArtAlphaPipeline
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
>>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A small cactus with a happy face in the Sahara desert."
>>> image = pipe(prompt).images[0]
```
"""
class PixArtAlphaPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using PixArt-Alpha.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. PixArt-Alpha uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`Transformer2DModel`]):
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
bad_punct_regex = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: Transformer2DModel,
scheduler: DPMSolverMultistepScheduler,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
def mask_text_embeddings(self, emb, mask):
if emb.shape[0] == 1:
keep_index = mask.sum().item()
return emb[:, :, :keep_index, :], keep_index
else:
masked_feature = emb * mask[:, None, :, None]
return masked_feature, emb.shape[2]
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
mask_feature: bool = True,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
PixArt-Alpha, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
string.
clean_caption (bool, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
"""
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# See Section 3.1. of the paper.
max_length = 120
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {max_length} tokens: {removed_text}"
)
attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds_attention_mask = attention_mask
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
else:
negative_prompt_embeds = None
# Perform additional masking.
if mask_feature:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
masked_negative_prompt_embeds = (
negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
)
return masked_prompt_embeds, masked_negative_prompt_embeds
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[‘’]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
mask_feature: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
self.check_inputs(
prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clean_caption=clean_caption,
mask_feature=mask_feature,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if self.transformer.config.sample_size == 128:
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
image = latents
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -19,13 +19,10 @@ class StableDiffusionPipelineOutput(BaseOutput):
nsfw_content_detected (`List[bool]`)
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
`None` if safety checking could not be performed.
workflow (`dict`):
Dictionary containing pipeline component configurations and call arguments
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
workflow: Optional[dict] = None
if is_flax_available():
@@ -623,7 +623,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
return_workflow: bool = False,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -678,8 +677,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
return_workflow (`bool`, *optional*, defaults to `False`):
Whether to return used pipeline call arguments.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -702,13 +699,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# We do this first to capture the "True" call values. If we do this at a later point in time,
# we cannot ensure that the call values weren't changed during the process.
workflow = None
if return_workflow:
if generator is None:
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
workflow = self.populate_workflow_from_pipeline()
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
@@ -865,11 +855,6 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
self.maybe_free_model_hooks()
if not return_dict:
outputs = (image, has_nsfw_concept)
return (image, has_nsfw_concept)
if return_workflow:
outputs += (workflow,)
return outputs
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
-3
View File
@@ -14,7 +14,6 @@
import importlib
import os
import numpy as np
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
from packaging import version
@@ -33,13 +32,11 @@ FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
WORKFLOW_NAME = "diffusion_workflow.json"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
MAX_SEED = np.iinfo(np.int32).max
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
@@ -572,6 +572,21 @@ class PaintByExamplePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class PixArtAlphaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
-161
View File
@@ -1,161 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for managing workflows."""
import json
import os
from pathlib import PosixPath
from typing import Union
import numpy as np
from huggingface_hub import create_repo
from . import __version__
from .utils import PushToHubMixin, logging
from .utils.constants import WORKFLOW_NAME
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_NON_CALL_ARGUMENTS = {"_name_or_path", "scheduler_config", "_class_name", "_diffusers_version"}
class Workflow(dict, PushToHubMixin):
"""Class sub-classing from native Python dict to have support for interacting with the Hub."""
config_name = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.config_name = WORKFLOW_NAME
self._internal_dict = {}
def __setitem__(self, __key, __value):
self._internal_dict[__key] = __value
return super().__setitem__(__key, __value)
def update(self, __m, **kwargs):
self._internal_dict.update(__m, **kwargs)
super().update(__m, **kwargs)
def pop(self, key, *args):
self._internal_dict.pop(key, *args)
return super().pop(key, *args)
# Copied from diffusers.configuration_utils.ConfigMixin.to_json_string
def to_json_string(self) -> str:
"""
Serializes the configuration instance to a JSON string.
Returns:
`str`:
String containing all the attributes that make up the configuration instance in JSON format.
"""
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
config_dict["_class_name"] = self.__class__.__name__
config_dict["_diffusers_version"] = __version__
def to_json_saveable(value):
if isinstance(value, np.ndarray):
value = value.tolist()
elif isinstance(value, PosixPath):
value = str(value)
return value
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
# Don't save "_ignore_files" or "_use_default_values"
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def save_workflow(
self,
save_directory: Union[str, os.PathLike],
push_to_hub: bool = False,
filename: str = WORKFLOW_NAME,
**kwargs,
):
"""
Saves a workflow to a directory.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the workflow JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
filename (`str`, *optional*, defaults to `workflow.json`):
Optional filename to use to serialize the workflow JSON.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
self.config_name = filename
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
output_config_file = os.path.join(save_directory, self.config_name)
with open(output_config_file, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
logger.info(f"Configuration saved in {output_config_file}")
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
push_to_hub: bool = False,
filename: str = WORKFLOW_NAME,
**kwargs,
):
"""
Saves a workflow to a directory. This internally calls [`Workflow.save_workflow`], This method exists to have
feature parity with [`PushToHubMixin.push_to_hub`].
Args:
save_directory (`str` or `os.PathLike`):
Directory where the workflow JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
filename (`str`, *optional*, defaults to `workflow.json`):
Optional filename to use to serialize the workflow JSON.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
self.save_workflow(
save_directory=save_directory,
push_to_hub=push_to_hub,
filename=filename,
**kwargs,
)
-171
View File
@@ -1,171 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
import unittest
import uuid
import numpy as np
import torch
from huggingface_hub import delete_repo, hf_hub_download
from test_utils import TOKEN, USER, is_staging_test
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.utils.constants import WORKFLOW_NAME
from diffusers.utils.testing_utils import torch_device
from diffusers.workflow_utils import Workflow
class WorkflowFastTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=1,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=64,
layer_norm_eps=1e-05,
num_attention_heads=8,
num_hidden_layers=3,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "np",
}
return inputs
def test_workflow_with_stable_diffusion(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs, return_workflow=True)
image = output.images
image_slice = image[0, -3:, -3:, -1]
with tempfile.TemporaryDirectory() as tmpdirname:
output.workflow.save_pretrained(tmpdirname)
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
sd_pipe.load_workflow(tmpdirname)
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs)
image = output.images
workflow_image_slice = image[0, -3:, -3:, -1]
self.assertTrue(np.allclose(image_slice, workflow_image_slice))
@is_staging_test
class WorkflowPushToHubTester(unittest.TestCase):
identifier = uuid.uuid4()
repo_id = f"test-workflow-{identifier}"
org_repo_id = f"valid_org/{repo_id}-org"
def compare_workflow_values(self, repo_id: str, actual_workflow: dict):
local_path = hf_hub_download(repo_id=repo_id, filename=WORKFLOW_NAME, token=TOKEN)
with open(local_path) as f:
locally_loaded_workflow = json.load(f)
for k in actual_workflow:
assert actual_workflow[k] == locally_loaded_workflow[k]
def test_push_to_hub(self):
workflow = Workflow()
workflow.update({"prompt": "hey", "num_inference_steps": 25})
workflow.push_to_hub(self.repo_id, token=TOKEN)
self.compare_workflow_values(repo_id=f"{USER}/{self.repo_id}", actual_workflow=workflow)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
def test_push_to_hub_in_organization(self):
workflow = Workflow()
workflow.update({"prompt": "hey", "num_inference_steps": 25})
workflow.push_to_hub(self.org_repo_id, token=TOKEN)
self.compare_workflow_values(repo_id=self.org_repo_id, actual_workflow=workflow)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
View File
+262
View File
@@ -0,0 +1,262 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
DDIMScheduler,
PixArtAlphaPipeline,
Transformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PixArtAlphaPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
def get_dummy_components(self):
torch.manual_seed(0)
transformer = Transformer2DModel(
sample_size=8,
num_layers=2,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,
caption_channels=32,
in_channels=4,
cross_attention_dim=24,
out_channels=8,
attention_bias=True,
activation_fn="gelu-approximate",
num_embeds_ada_norm=1000,
norm_type="ada_norm_single",
norm_elementwise_affine=False,
norm_eps=1e-6,
)
vae = AutoencoderKL()
scheduler = DDIMScheduler()
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "numpy",
}
return inputs
def test_sequential_cpu_offload_forward_pass(self):
# TODO(PVP, Sayak) need to fix later
return
def test_save_load_optional_components(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs["prompt"]
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, mask_feature=False)
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"mask_feature": False,
}
# set all optional components to None
for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for optional_component in pipe._optional_components:
self.assertTrue(
getattr(pipe_loaded, optional_component) is None,
f"`{optional_component}` did not stay set to None after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"mask_feature": False,
}
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1e-4)
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
print(torch.from_numpy(image_slice.flatten()))
self.assertEqual(image.shape, (1, 8, 8, 3))
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
# TODO: needs to be updated.
@slow
@require_torch_gpu
class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_pixart_1024_fast(self):
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1323])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_pixart_512_fast(self):
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0266])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_pixart_1024(self):
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt, generator=generator, output_type="np").images
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.1501, 0.1755, 0.1877, 0.1445, 0.1665, 0.1763, 0.1389, 0.176, 0.2031])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_pixart_512(self):
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt, generator=generator, output_type="np").images
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.2515, 0.2593, 0.2593, 0.2544, 0.2759, 0.2788, 0.2812, 0.3169, 0.332])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)