Compare commits
40 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 747039b5c8 | |||
| ab6672fecd | |||
| f90a5139a2 | |||
| a2bc2e14b9 | |||
| f427345ab1 | |||
| 6e221334cd | |||
| 53bc30dd45 | |||
| eacf5e34eb | |||
| 4c05f7856a | |||
| bbd3572044 | |||
| f948778322 | |||
| 4684ea2fe8 | |||
| b64f835ea7 | |||
| 880c0fdd36 | |||
| c36f1c3160 | |||
| 0a08d41961 | |||
| e185084a5d | |||
| b21729225a | |||
| 8a812e4e14 | |||
| bf92e746c0 | |||
| b785a155d6 | |||
| d486f0e846 | |||
| 3351270627 | |||
| 4520e1221a | |||
| 618260409f | |||
| dadd55fb36 | |||
| 1b6c7ea74e | |||
| b41f809a4e | |||
| 0f55c17e17 | |||
| 5058d27f12 | |||
| 748c1b3ec7 | |||
| 523507034f | |||
| 46c751e970 | |||
| bc1d28c888 | |||
| af378c1dd1 | |||
| 6ba4c5395f | |||
| c1e4529541 | |||
| d29d97b616 | |||
| 7d4a257c7f | |||
| 141cd52d56 |
@@ -35,14 +35,15 @@ jobs:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
fetch-depth: 0
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .
|
||||
python -m pip install -e .[quality,test]
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
echo $(git --version)
|
||||
- name: Fetch Tests
|
||||
run: |
|
||||
python utils/tests_fetcher.py | tee test_preparation.txt
|
||||
@@ -110,7 +111,7 @@ jobs:
|
||||
continue-on-error: true
|
||||
run: |
|
||||
cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
|
||||
cat reports/${{ matrix.modules }}_tests_cpu/failures_short.txt
|
||||
cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
|
||||
+1
-1
@@ -355,7 +355,7 @@ You will need basic `git` proficiency to be able to contribute to
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L244)):
|
||||
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L265)):
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/diffusers) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
|
||||
@@ -41,7 +41,7 @@ repo-consistency:
|
||||
|
||||
quality:
|
||||
ruff check $(check_dirs) setup.py
|
||||
ruff format --check $(check_dirs) setup.py
|
||||
ruff format --check $(check_dirs) setup.py
|
||||
python utils/check_doc_toc.py
|
||||
|
||||
# Format source code automatically and check is there are any problems left that need manual fixing
|
||||
|
||||
@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mech
|
||||
## AttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.AttnProcessor2_0
|
||||
|
||||
## FusedAttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
|
||||
|
||||
## LoRAAttnProcessor
|
||||
[[autodoc]] models.attention_processor.LoRAAttnProcessor
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ The abstract from the paper is:
|
||||
|
||||
*Model-based reinforcement learning methods often use learning only for the purpose of estimating an approximate dynamics model, offloading the rest of the decision-making work to classical trajectory optimizers. While conceptually simple, this combination has a number of empirical shortcomings, suggesting that learned models may not be well-suited to standard trajectory optimization. In this paper, we consider what it would look like to fold as much of the trajectory optimization pipeline as possible into the modeling problem, such that sampling from the model and planning with it become nearly identical. The core of our technical approach lies in a diffusion probabilistic model that plans by iteratively denoising trajectories. We show how classifier-guided sampling and image inpainting can be reinterpreted as coherent planning strategies, explore the unusual and useful properties of diffusion-based planning methods, and demonstrate the effectiveness of our framework in control settings that emphasize long-horizon decision-making and test-time flexibility.*
|
||||
|
||||
You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb).
|
||||
You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/drive/1rXm8CX4ZdN5qivjJ2lhwhkOmt_m0CvU0#scrollTo=6HXJvhyqcITc&uniqifier=1).
|
||||
|
||||
The script to run the model is available [here](https://github.com/huggingface/diffusers/tree/main/examples/reinforcement_learning).
|
||||
|
||||
|
||||
@@ -297,17 +297,37 @@ if you don't know yet what specific component you would like to add:
|
||||
- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
|
||||
- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
|
||||
|
||||
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that
|
||||
we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
|
||||
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
|
||||
open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
|
||||
pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
|
||||
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
|
||||
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
|
||||
|
||||
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
|
||||
original author directly on the PR so that they can follow the progress and potentially help with questions.
|
||||
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the original author directly on the PR so that they can follow the progress and potentially help with questions.
|
||||
|
||||
If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
|
||||
|
||||
#### Copied from mechanism
|
||||
|
||||
A unique and important feature to understand when adding any pipeline, model or scheduler code is the `# Copied from` mechanism. You'll see this all over the Diffusers codebase, and the reason we use it is to keep the codebase easy to understand and maintain. Marking code with the `# Copied from` mechanism forces the marked code to be identical to the code it was copied from. This makes it easy to update and propagate changes across many files whenever you run `make fix-copies`.
|
||||
|
||||
For example, in the code example below, [`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is the original code and `AltDiffusionPipelineOutput` uses the `# Copied from` mechanism to copy it. The only difference is changing the class prefix from `Stable` to `Alt`.
|
||||
|
||||
```py
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt
|
||||
class AltDiffusionPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Alt Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
|
||||
num_channels)`.
|
||||
nsfw_content_detected (`List[bool]`)
|
||||
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
||||
`None` if safety checking could not be performed.
|
||||
"""
|
||||
```
|
||||
|
||||
To learn more, read this section of the [~Don't~ Repeat Yourself*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static) blog post.
|
||||
|
||||
## How to write a good issue
|
||||
|
||||
**The better your issue is written, the higher the chances that it will be quickly resolved.**
|
||||
|
||||
@@ -20,6 +20,8 @@ The Kandinsky models are a series of multilingual text-to-image generation model
|
||||
|
||||
[Kandinsky 2.2](../api/pipelines/kandinsky_v22) improves on the previous model by replacing the image encoder of the image prior model with a larger CLIP-ViT-G model to improve quality. The image prior model was also retrained on images with different resolutions and aspect ratios to generate higher-resolution images and different image sizes.
|
||||
|
||||
[Kandinsky 3](../api/pipelines/kandinsky3) simplifies the architecture and shifts away from the two-stage generation process involving the prior model and diffusion model. Instead, Kandinsky 3 uses [Flan-UL2](https://huggingface.co/google/flan-ul2) to encode text, a UNet with [BigGan-deep](https://hf.co/papers/1809.11096) blocks, and [Sber-MoVQGAN](https://github.com/ai-forever/MoVQGAN) to decode the latents into images. Text understanding and generated image quality are primarily achieved by using a larger text encoder and UNet.
|
||||
|
||||
This guide will show you how to use the Kandinsky models for text-to-image, image-to-image, inpainting, interpolation, and more.
|
||||
|
||||
Before you begin, make sure you have the following libraries installed:
|
||||
@@ -33,6 +35,10 @@ Before you begin, make sure you have the following libraries installed:
|
||||
|
||||
Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding.
|
||||
|
||||
<br>
|
||||
|
||||
Kandinsky 3 has a more concise architecture and it doesn't require a prior model. This means it's usage is identical to other diffusion models like [Stable Diffusion XL](sdxl).
|
||||
|
||||
</Tip>
|
||||
|
||||
## Text-to-image
|
||||
@@ -91,6 +97,23 @@ image
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-text-to-image.png"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Kandinsky 3">
|
||||
|
||||
Kandinsky 3 doesn't require a prior model so you can directly load the [`Kandinsky3Pipeline`] and pass a prompt to generate an image:
|
||||
|
||||
```py
|
||||
from diffusers import Kandinsky3Pipeline
|
||||
import torch
|
||||
|
||||
pipeline = Kandinsky3Pipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
|
||||
image = pipeline(prompt).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
@@ -161,6 +184,20 @@ prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kan
|
||||
pipeline = KandinskyV22Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Kandinsky 3">
|
||||
|
||||
Kandinsky 3 doesn't require a prior model so you can directly load the image-to-image pipeline:
|
||||
|
||||
```py
|
||||
from diffusers import Kandinsky3Img2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = Kandinsky3Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
@@ -218,6 +255,14 @@ make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], r
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-image-to-image.png"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Kandinsky 3">
|
||||
|
||||
```py
|
||||
image = pipeline(prompt, negative_prompt=negative_prompt, image=image, strength=0.75, num_inference_steps=25).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
@@ -53,8 +53,9 @@ frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
|
||||
export_to_video(frames, "generated.mp4", fps=7)
|
||||
```
|
||||
|
||||
<video width="1024" height="576" controls>
|
||||
<source src="https://i.imgur.com/jJzVDKw.mp4" type="video/mp4">
|
||||
<video controls width="1024" height="576">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.webm" type="video/webm" />
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4" type="video/mp4" />
|
||||
</video>
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -54,7 +54,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
|
||||
from diffusers.models.lora import LoRALinearLayer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr, unet_lora_state_dict
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
@@ -62,16 +62,51 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
base_model=str,
|
||||
train_text_encoder=False,
|
||||
train_text_encoder_ti=False,
|
||||
token_abstraction_dict=None,
|
||||
instance_prompt=str,
|
||||
validation_prompt=str,
|
||||
repo_folder=None,
|
||||
@@ -83,10 +118,33 @@ def save_model_card(
|
||||
img_str += f"""
|
||||
- text: '{validation_prompt if validation_prompt else ' ' }'
|
||||
output:
|
||||
url: >-
|
||||
url:
|
||||
"image_{i}.png"
|
||||
"""
|
||||
|
||||
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
|
||||
diffusers_imports_pivotal = ""
|
||||
diffusers_example_pivotal = ""
|
||||
if train_text_encoder_ti:
|
||||
trigger_str = (
|
||||
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
|
||||
"in you prompt with the new inserted tokens:\n"
|
||||
)
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
|
||||
"""
|
||||
if token_abstraction_dict:
|
||||
for key, value in token_abstraction_dict.items():
|
||||
tokens = "".join(value)
|
||||
trigger_str += f"""
|
||||
to trigger concept `{key}` → use `{tokens}` in your prompt \n
|
||||
"""
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
tags:
|
||||
@@ -96,9 +154,7 @@ tags:
|
||||
- diffusers
|
||||
- lora
|
||||
- template:sd-lora
|
||||
widget:
|
||||
{img_str}
|
||||
---
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
@@ -112,16 +168,35 @@ license: openrail++
|
||||
|
||||
## Model description
|
||||
|
||||
These are {repo_id} LoRA adaption weights for {base_model}.
|
||||
### These are {repo_id} LoRA adaption weights for {base_model}.
|
||||
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
|
||||
|
||||
LoRA for the text encoder was enabled: {train_text_encoder}.
|
||||
|
||||
Pivotal tuning was enabled: {train_text_encoder_ti}.
|
||||
|
||||
Special VAE used for training: {vae_path}.
|
||||
|
||||
## Trigger words
|
||||
|
||||
You should use {instance_prompt} to trigger the image generation.
|
||||
{trigger_str}
|
||||
|
||||
## Download model
|
||||
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
{diffusers_imports_pivotal}
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
|
||||
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
|
||||
{diffusers_example_pivotal}
|
||||
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
|
||||
```
|
||||
|
||||
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
||||
|
||||
## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
|
||||
|
||||
Weights for this model are available in Safetensors format.
|
||||
|
||||
@@ -174,6 +249,12 @@ def parse_args(input_args=None):
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
@@ -181,20 +262,26 @@ def parse_args(input_args=None):
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
" or to a folder containing files that 🤗 Datasets can understand.To load the custom captions, the training set directory needs to follow the structure of a "
|
||||
"datasets ImageFolder, containing both the images and the corresponding caption for each image. see: "
|
||||
"https://huggingface.co/docs/datasets/image_dataset for more information"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
help="The config of the Dataset. In some cases, a dataset may have more than one configuration (for example "
|
||||
"if it contains different subsets of data within, and you only wish to load a specific subset - in that case specify the desired configuration using --dataset_config_name. Leave as "
|
||||
"None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("A folder containing the training data. "),
|
||||
help="A path to local folder containing the training data of instance images. Specify this arg instead of "
|
||||
"--dataset_name if you wish to train using a local folder without custom captions. If you wish to train with custom captions please specify "
|
||||
"--dataset_name instead.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -237,15 +324,18 @@ def parse_args(input_args=None):
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_abstraction",
|
||||
type=str,
|
||||
default="TOK",
|
||||
help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
|
||||
"captions - e.g. TOK",
|
||||
"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. "
|
||||
"'TOK,TOK2,TOK3' etc.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_new_tokens_per_abstraction",
|
||||
type=int,
|
||||
default=2,
|
||||
help="number of new tokens inserted to the tokenizers per token_abstraction value when "
|
||||
help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
|
||||
"--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
|
||||
"tokens - <si><si+1> ",
|
||||
)
|
||||
@@ -455,7 +545,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--train_text_encoder_frac",
|
||||
type=float,
|
||||
default=0.5,
|
||||
default=1.0,
|
||||
help=("The percentage of epochs to perform text encoder tuning"),
|
||||
)
|
||||
|
||||
@@ -488,7 +578,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
|
||||
"--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -596,17 +686,6 @@ def parse_args(input_args=None):
|
||||
"inversion training check `--train_text_encoder_ti`"
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
if isinstance(args.token_abstraction, str):
|
||||
args.token_abstraction = [args.token_abstraction]
|
||||
elif isinstance(args.token_abstraction, List):
|
||||
args.token_abstraction = args.token_abstraction
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. "
|
||||
f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)"
|
||||
)
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
@@ -679,12 +758,19 @@ class TokenEmbeddingsHandler:
|
||||
def save_embeddings(self, file_path: str):
|
||||
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
|
||||
tensors = {}
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
|
||||
self.tokenizers[0]
|
||||
), "Tokenizers should be the same."
|
||||
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
|
||||
tensors[f"text_encoders_{idx}"] = new_token_embeddings
|
||||
|
||||
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
|
||||
# text_encoder 1) to keep compatible with the ecosystem.
|
||||
# Note: When loading with diffusers, any name can work - simply specify in inference
|
||||
tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
|
||||
# tensors[f"text_encoders_{idx}"] = new_token_embeddings
|
||||
|
||||
save_file(tensors, file_path)
|
||||
|
||||
@@ -696,19 +782,6 @@ class TokenEmbeddingsHandler:
|
||||
def device(self):
|
||||
return self.text_encoders[0].device
|
||||
|
||||
# def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
|
||||
# # Assuming new tokens are of the format <s_i>
|
||||
# self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
|
||||
# special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
# tokenizer.add_special_tokens(special_tokens_dict)
|
||||
# text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
#
|
||||
# self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
|
||||
# assert self.train_ids is not None, "New tokens could not be converted to IDs."
|
||||
# text_encoder.text_model.embeddings.token_embedding.weight.data[
|
||||
# self.train_ids
|
||||
# ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
@@ -730,15 +803,6 @@ class TokenEmbeddingsHandler:
|
||||
new_embeddings = new_embeddings * (off_ratio**0.1)
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
|
||||
# def load_embeddings(self, file_path: str):
|
||||
# with safe_open(file_path, framework="pt", device=self.device.type) as f:
|
||||
# for idx in range(len(self.text_encoders)):
|
||||
# text_encoder = self.text_encoders[idx]
|
||||
# tokenizer = self.tokenizers[idx]
|
||||
#
|
||||
# loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
|
||||
# self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
"""
|
||||
@@ -751,6 +815,12 @@ class DreamBoothDataset(Dataset):
|
||||
instance_data_root,
|
||||
instance_prompt,
|
||||
class_prompt,
|
||||
dataset_name,
|
||||
dataset_config_name,
|
||||
cache_dir,
|
||||
image_column,
|
||||
caption_column,
|
||||
train_text_encoder_ti,
|
||||
class_data_root=None,
|
||||
class_num=None,
|
||||
token_abstraction_dict=None, # token mapping for textual inversion
|
||||
@@ -765,10 +835,10 @@ class DreamBoothDataset(Dataset):
|
||||
self.custom_instance_prompts = None
|
||||
self.class_prompt = class_prompt
|
||||
self.token_abstraction_dict = token_abstraction_dict
|
||||
|
||||
self.train_text_encoder_ti = train_text_encoder_ti
|
||||
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
|
||||
# we load the training data using load_dataset
|
||||
if args.dataset_name is not None:
|
||||
if dataset_name is not None:
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
except ImportError:
|
||||
@@ -781,26 +851,25 @@ class DreamBoothDataset(Dataset):
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
dataset_name,
|
||||
dataset_config_name,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
# Preprocessing the datasets.
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
# 6. Get the column names for input/target.
|
||||
if args.image_column is None:
|
||||
if image_column is None:
|
||||
image_column = column_names[0]
|
||||
logger.info(f"image column defaulting to {image_column}")
|
||||
else:
|
||||
image_column = args.image_column
|
||||
if image_column not in column_names:
|
||||
raise ValueError(
|
||||
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
)
|
||||
instance_images = dataset["train"][image_column]
|
||||
|
||||
if args.caption_column is None:
|
||||
if caption_column is None:
|
||||
logger.info(
|
||||
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
|
||||
"contains captions/prompts for the images, make sure to specify the "
|
||||
@@ -808,11 +877,11 @@ class DreamBoothDataset(Dataset):
|
||||
)
|
||||
self.custom_instance_prompts = None
|
||||
else:
|
||||
if args.caption_column not in column_names:
|
||||
if caption_column not in column_names:
|
||||
raise ValueError(
|
||||
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
)
|
||||
custom_instance_prompts = dataset["train"][args.caption_column]
|
||||
custom_instance_prompts = dataset["train"][caption_column]
|
||||
# create final list of captions according to --repeats
|
||||
self.custom_instance_prompts = []
|
||||
for caption in custom_instance_prompts:
|
||||
@@ -867,7 +936,7 @@ class DreamBoothDataset(Dataset):
|
||||
if self.custom_instance_prompts:
|
||||
caption = self.custom_instance_prompts[index % self.num_instance_images]
|
||||
if caption:
|
||||
if args.train_text_encoder_ti:
|
||||
if self.train_text_encoder_ti:
|
||||
# replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
|
||||
for token_abs, token_replacement in self.token_abstraction_dict.items():
|
||||
caption = caption.replace(token_abs, "".join(token_replacement))
|
||||
@@ -1021,6 +1090,7 @@ def main(args):
|
||||
args.pretrained_model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -1052,17 +1122,25 @@ def main(args):
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
model_id = args.hub_model_id or Path(args.output_dir).name
|
||||
repo_id = None
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id
|
||||
|
||||
# Load the tokenizers
|
||||
tokenizer_one = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
use_fast=False,
|
||||
)
|
||||
tokenizer_two = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer_2",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
# import correct text encoder classes
|
||||
@@ -1076,10 +1154,10 @@ def main(args):
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae_path = (
|
||||
args.pretrained_model_name_or_path
|
||||
@@ -1087,16 +1165,24 @@ def main(args):
|
||||
else args.pretrained_vae_model_name_or_path
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
# we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
|
||||
# TOK2" -> ["TOK", "TOK2"] etc.
|
||||
token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
|
||||
logger.info(f"list of token identifiers: {token_abstraction_list}")
|
||||
|
||||
token_abstraction_dict = {}
|
||||
token_idx = 0
|
||||
for i, token in enumerate(args.token_abstraction):
|
||||
for i, token in enumerate(token_abstraction_list):
|
||||
token_abstraction_dict[token] = [
|
||||
f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction)
|
||||
]
|
||||
@@ -1216,6 +1302,8 @@ def main(args):
|
||||
text_lora_parameters_one = []
|
||||
for name, param in text_encoder_one.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_one.append(param)
|
||||
else:
|
||||
@@ -1223,6 +1311,8 @@ def main(args):
|
||||
text_lora_parameters_two = []
|
||||
for name, param in text_encoder_two.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_two.append(param)
|
||||
else:
|
||||
@@ -1309,12 +1399,16 @@ def main(args):
|
||||
# different learning rate for text encoder and unet
|
||||
text_lora_parameters_one_with_lr = {
|
||||
"params": text_lora_parameters_one,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder
|
||||
if args.adam_weight_decay_text_encoder
|
||||
else args.adam_weight_decay,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
text_lora_parameters_two_with_lr = {
|
||||
"params": text_lora_parameters_two,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder
|
||||
if args.adam_weight_decay_text_encoder
|
||||
else args.adam_weight_decay,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [
|
||||
@@ -1399,6 +1493,12 @@ def main(args):
|
||||
instance_data_root=args.instance_data_dir,
|
||||
instance_prompt=args.instance_prompt,
|
||||
class_prompt=args.class_prompt,
|
||||
dataset_name=args.dataset_name,
|
||||
dataset_config_name=args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
image_column=args.image_column,
|
||||
train_text_encoder_ti=args.train_text_encoder_ti,
|
||||
caption_column=args.caption_column,
|
||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||
token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
|
||||
class_num=args.num_class_images,
|
||||
@@ -1494,6 +1594,12 @@ def main(args):
|
||||
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
if args.train_text_encoder_ti and args.validation_prompt:
|
||||
# replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
|
||||
for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
|
||||
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
|
||||
print("validation prompt:", args.validation_prompt)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -1593,27 +1699,10 @@ def main(args):
|
||||
if epoch == num_train_epochs_text_encoder:
|
||||
print("PIVOT HALFWAY", epoch)
|
||||
# stopping optimization of text_encoder params
|
||||
params_to_optimize = params_to_optimize[:1]
|
||||
# reinitializing the optimizer to optimize only on unet params
|
||||
if args.optimizer.lower() == "prodigy":
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
decouple=args.prodigy_decouple,
|
||||
use_bias_correction=args.prodigy_use_bias_correction,
|
||||
safeguard_warmup=args.prodigy_safeguard_warmup,
|
||||
)
|
||||
else: # AdamW or 8-bit-AdamW
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
# re setting the optimizer to optimize only on unet params
|
||||
optimizer.param_groups[1]["lr"] = 0.0
|
||||
optimizer.param_groups[2]["lr"] = 0.0
|
||||
|
||||
else:
|
||||
# still optimizng the text encoder
|
||||
text_encoder_one.train()
|
||||
@@ -1628,7 +1717,7 @@ def main(args):
|
||||
with accelerator.accumulate(unet):
|
||||
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
||||
prompts = batch["prompts"]
|
||||
print(prompts)
|
||||
# print(prompts)
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if freeze_text_encoder:
|
||||
@@ -1801,12 +1890,18 @@ def main(args):
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline
|
||||
if not args.train_text_encoder:
|
||||
if freeze_text_encoder:
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder_2",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
@@ -1815,6 +1910,7 @@ def main(args):
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
@@ -1892,10 +1988,15 @@ def main(args):
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
@@ -1938,21 +2039,23 @@ def main(args):
|
||||
}
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/embeddings.safetensors",
|
||||
)
|
||||
save_model_card(
|
||||
repo_id,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
instance_prompt=args.instance_prompt,
|
||||
validation_prompt=args.validation_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
vae_path=args.pretrained_vae_model_name_or_path,
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/embeddings.safetensors",
|
||||
)
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
train_text_encoder_ti=args.train_text_encoder_ti,
|
||||
token_abstraction_dict=train_dataset.token_abstraction_dict,
|
||||
instance_prompt=args.instance_prompt,
|
||||
validation_prompt=args.validation_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
vae_path=args.pretrained_vae_model_name_or_path,
|
||||
)
|
||||
if args.push_to_hub:
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
|
||||
@@ -48,8 +48,9 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
|
||||
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
|
||||
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
|
||||
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
|
||||
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
|
||||
|
|
||||
| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
```py
|
||||
@@ -77,6 +78,7 @@ from diffusers import DiffusionPipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"longlian/lmd_plus",
|
||||
custom_pipeline="llm_grounded_diffusion",
|
||||
custom_revision="main",
|
||||
variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -2524,6 +2526,181 @@ images[0].save("controlnet_and_adapter_inpaint.png")
|
||||
|
||||
```
|
||||
|
||||
### Regional Prompting Pipeline
|
||||
This pipeline is a port of the [Regional Prompter extension](https://github.com/hako-mikan/sd-webui-regional-prompter) for [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to diffusers.
|
||||
This code implements a pipeline for the Stable Diffusion model, enabling the division of the canvas into multiple regions, with different prompts applicable to each region. Users can specify regions in two ways: using `Cols` and `Rows` modes for grid-like divisions, or the `Prompt` mode for regions calculated based on prompts.
|
||||
|
||||

|
||||
|
||||
### Usage
|
||||
### Sample Code
|
||||
```
|
||||
from from examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline
|
||||
pipe = RegionalPromptingStableDiffusionPipeline.from_single_file(model_path, vae=vae)
|
||||
|
||||
rp_args = {
|
||||
"mode":"rows",
|
||||
"div": "1;1;1"
|
||||
}
|
||||
|
||||
prompt ="""
|
||||
green hair twintail BREAK
|
||||
red blouse BREAK
|
||||
blue skirt
|
||||
"""
|
||||
|
||||
images = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=7.5,
|
||||
height = 768,
|
||||
width = 512,
|
||||
num_inference_steps =20,
|
||||
num_images_per_prompt = 1,
|
||||
rp_args = rp_args
|
||||
).images
|
||||
|
||||
time = time.strftime(r"%Y%m%d%H%M%S")
|
||||
i = 1
|
||||
for image in images:
|
||||
i += 1
|
||||
fileName = f'img-{time}-{i+1}.png'
|
||||
image.save(fileName)
|
||||
```
|
||||
### Cols, Rows mode
|
||||
In the Cols, Rows mode, you can split the screen vertically and horizontally and assign prompts to each region. The split ratio can be specified by 'div', and you can set the division ratio like '3;3;2' or '0.1;0.5'. Furthermore, as will be described later, you can also subdivide the split Cols, Rows to specify more complex regions.
|
||||
|
||||
In this image, the image is divided into three parts, and a separate prompt is applied to each. The prompts are divided by 'BREAK', and each is applied to the respective region.
|
||||

|
||||
```
|
||||
green hair twintail BREAK
|
||||
red blouse BREAK
|
||||
blue skirt
|
||||
```
|
||||
|
||||
### 2-Dimentional division
|
||||
The prompt consists of instructions separated by the term `BREAK` and is assigned to different regions of a two-dimensional space. The image is initially split in the main splitting direction, which in this case is rows, due to the presence of a single semicolon`;`, dividing the space into an upper and a lower section. Additional sub-splitting is then applied, indicated by commas. The upper row is split into ratios of `2:1:1`, while the lower row is split into a ratio of `4:6`. Rows themselves are split in a `1:2` ratio. According to the reference image, the blue sky is designated as the first region, green hair as the second, the bookshelf as the third, and so on, in a sequence based on their position from the top left. The terrarium is placed on the desk in the fourth region, and the orange dress and sofa are in the fifth region, conforming to their respective splits.
|
||||
```
|
||||
rp_args = {
|
||||
"mode":"rows",
|
||||
"div": "1,2,1,1;2,4,6"
|
||||
}
|
||||
|
||||
prompt ="""
|
||||
blue sky BREAK
|
||||
green hair BREAK
|
||||
book shelf BREAK
|
||||
terrarium on desk BREAK
|
||||
orange dress and sofa
|
||||
"""
|
||||
```
|
||||

|
||||
|
||||
### Prompt Mode
|
||||
There are limitations to methods of specifying regions in advance. This is because specifying regions can be a hindrance when designating complex shapes or dynamic compositions. In the region specified by the prompt, the regions is determined after the image generation has begun. This allows us to accommodate compositions and complex regions.
|
||||
For further infomagen, see [here](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/main/prompt_en.md).
|
||||
### syntax
|
||||
```
|
||||
baseprompt target1 target2 BREAK
|
||||
effect1, target1 BREAK
|
||||
effect2 ,target2
|
||||
```
|
||||
|
||||
First, write the base prompt. In the base prompt, write the words (target1, target2) for which you want to create a mask. Next, separate them with BREAK. Next, write the prompt corresponding to target1. Then enter a comma and write target1. The order of the targets in the base prompt and the order of the BREAK-separated targets can be back to back.
|
||||
|
||||
```
|
||||
target2 baseprompt target1 BREAK
|
||||
effect1, target1 BREAK
|
||||
effect2 ,target2
|
||||
```
|
||||
is also effective.
|
||||
|
||||
### Sample
|
||||
In this example, masks are calculated for shirt, tie, skirt, and color prompts are specified only for those regions.
|
||||
```
|
||||
rp_args = {
|
||||
"mode":"prompt-ex",
|
||||
"save_mask":True,
|
||||
"th": "0.4,0.6,0.6",
|
||||
}
|
||||
|
||||
prompt ="""
|
||||
a girl in street with shirt, tie, skirt BREAK
|
||||
red, shirt BREAK
|
||||
green, tie BREAK
|
||||
blue , skirt
|
||||
"""
|
||||
```
|
||||

|
||||
### threshold
|
||||
The threshold used to determine the mask created by the prompt. This can be set as many times as there are masks, as the range varies widely depending on the target prompt. If multiple regions are used, enter them separated by commas. For example, hair tends to be ambiguous and requires a small value, while face tends to be large and requires a small value. These should be ordered by BREAK.
|
||||
|
||||
```
|
||||
a lady ,hair, face BREAK
|
||||
red, hair BREAK
|
||||
tanned ,face
|
||||
```
|
||||
`threshold : 0.4,0.6`
|
||||
If only one input is given for multiple regions, they are all assumed to be the same value.
|
||||
|
||||
### Prompt and Prompt-EX
|
||||
The difference is that in Prompt, duplicate regions are added, whereas in Prompt-EX, duplicate regions are overwritten sequentially. Since they are processed in order, setting a TARGET with a large regions first makes it easier for the effect of small regions to remain unmuffled.
|
||||
|
||||
### Accuracy
|
||||
In the case of a 512 x 512 image, Attention mode reduces the size of the region to about 8 x 8 pixels deep in the U-Net, so that small regions get mixed up; Latent mode calculates 64*64, so that the region is exact.
|
||||
```
|
||||
girl hair twintail frills,ribbons, dress, face BREAK
|
||||
girl, ,face
|
||||
```
|
||||
|
||||
### Mask
|
||||
When an image is generated, the generated mask is displayed. It is generated at the same size as the image, but is actually used at a much smaller size.
|
||||
|
||||
|
||||
### Use common prompt
|
||||
You can attach the prompt up to ADDCOMM to all prompts by separating it first with ADDCOMM. This is useful when you want to include elements common to all regions. For example, when generating pictures of three people with different appearances, it's necessary to include the instruction of 'three people' in all regions. It's also useful when inserting quality tags and other things."For example, if you write as follows:
|
||||
```
|
||||
best quality, 3persons in garden, ADDCOMM
|
||||
a girl white dress BREAK
|
||||
a boy blue shirt BREAK
|
||||
an old man red suit
|
||||
```
|
||||
If common is enabled, this prompt is converted to the following:
|
||||
```
|
||||
best quality, 3persons in garden, a girl white dress BREAK
|
||||
best quality, 3persons in garden, a boy blue shirt BREAK
|
||||
best quality, 3persons in garden, an old man red suit
|
||||
```
|
||||
### Negative prompt
|
||||
Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
|
||||
|
||||
### Parameters
|
||||
To activate Regional Prompter, it is necessary to enter settings in rp_args. The items that can be set are as follows. rp_args is a dictionary type.
|
||||
|
||||
### Input Parameters
|
||||
Parameters are specified through the `rp_arg`(dictionary type).
|
||||
|
||||
```
|
||||
rp_args = {
|
||||
"mode":"rows",
|
||||
"div": "1;1;1"
|
||||
}
|
||||
|
||||
pipe(prompt =prompt, rp_args = rp_args)
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Required Parameters
|
||||
- `mode`: Specifies the method for defining regions. Choose from `Cols`, `Rows`, `Prompt` or `Prompt-Ex`. This parameter is case-insensitive.
|
||||
- `divide`: Used in `Cols` and `Rows` modes. Details on how to specify this are provided under the respective `Cols` and `Rows` sections.
|
||||
- `th`: Used in `Prompt` mode. The method of specification is detailed under the `Prompt` section.
|
||||
|
||||
### Optional Parameters
|
||||
- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
|
||||
|
||||
The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.
|
||||
|
||||
## Diffusion Posterior Sampling Pipeline
|
||||
* Reference paper
|
||||
```
|
||||
@@ -2665,3 +2842,86 @@ images[0].save("controlnet_and_adapter_inpaint.png")
|
||||
* 
|
||||
* Reconstructed image:
|
||||
* 
|
||||
|
||||
### DemoFusion
|
||||
This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
|
||||
The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion).
|
||||
- `view_batch_size` (`int`, defaults to 16):
|
||||
The batch size for multiple denoising paths. Typically, a larger batch size can result in higher efficiency but comes with increased GPU memory requirements.
|
||||
|
||||
- `stride` (`int`, defaults to 64):
|
||||
The stride of moving local patches. A smaller stride is better for alleviating seam issues, but it also introduces additional computational overhead and inference time.
|
||||
|
||||
- `cosine_scale_1` (`float`, defaults to 3):
|
||||
Control the strength of skip-residual. For specific impacts, please refer to Appendix C in the DemoFusion paper.
|
||||
|
||||
- `cosine_scale_2` (`float`, defaults to 1):
|
||||
Control the strength of dilated sampling. For specific impacts, please refer to Appendix C in the DemoFusion paper.
|
||||
|
||||
- `cosine_scale_3` (`float`, defaults to 1):
|
||||
Control the strength of the Gaussian filter. For specific impacts, please refer to Appendix C in the DemoFusion paper.
|
||||
|
||||
- `sigma` (`float`, defaults to 1):
|
||||
The standard value of the Gaussian filter. Larger sigma promotes the global guidance of dilated sampling, but has the potential of over-smoothing.
|
||||
|
||||
- `multi_decoder` (`bool`, defaults to True):
|
||||
Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072, a tiled decoder becomes necessary.
|
||||
|
||||
- `show_image` (`bool`, defaults to False):
|
||||
Determine whether to show intermediate results during generation.
|
||||
```
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
custom_pipeline="pipeline_demofusion_sdxl",
|
||||
custom_revision="main",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Envision a portrait of an elderly woman, her face a canvas of time, framed by a headscarf with muted tones of rust and cream. Her eyes, blue like faded denim. Her attire, simple yet dignified."
|
||||
negative_prompt = "blurry, ugly, duplicate, poorly drawn, deformed, mosaic"
|
||||
|
||||
images = pipe(
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=3072,
|
||||
width=3072,
|
||||
view_batch_size=16,
|
||||
stride=64,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=7.5,
|
||||
cosine_scale_1=3,
|
||||
cosine_scale_2=1,
|
||||
cosine_scale_3=1,
|
||||
sigma=0.8,
|
||||
multi_decoder=True,
|
||||
show_image=True
|
||||
)
|
||||
```
|
||||
You can display and save the generated images as:
|
||||
```
|
||||
def image_grid(imgs, save_path=None):
|
||||
|
||||
w = 0
|
||||
for i, img in enumerate(imgs):
|
||||
h_, w_ = imgs[i].size
|
||||
w += w_
|
||||
h = h_
|
||||
grid = Image.new('RGB', size=(w, h))
|
||||
grid_w, grid_h = grid.size
|
||||
|
||||
w = 0
|
||||
for i, img in enumerate(imgs):
|
||||
h_, w_ = imgs[i].size
|
||||
grid.paste(img, box=(w, h - h_))
|
||||
if save_path != None:
|
||||
img.save(save_path + "/img_{}.jpg".format((i + 1) * 1024))
|
||||
w += w_
|
||||
|
||||
return grid
|
||||
|
||||
image_grid(images, save_path="./outputs/")
|
||||
```
|
||||

|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import ast
|
||||
import gc
|
||||
import inspect
|
||||
import math
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
@@ -23,16 +24,29 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.attention import Attention, GatedSelfAttentionDense
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import logging, replace_example_docstring
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
@@ -44,6 +58,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> pipe = DiffusionPipeline.from_pretrained(
|
||||
... "longlian/lmd_plus",
|
||||
... custom_pipeline="llm_grounded_diffusion",
|
||||
... custom_revision="main",
|
||||
... variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
@@ -96,7 +111,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
|
||||
# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
|
||||
DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
|
||||
DEFAULT_GUIDANCE_ATTN_KEYS = [
|
||||
("mid", 0, 0, 0),
|
||||
("up", 1, 0, 0),
|
||||
("up", 1, 1, 0),
|
||||
("up", 1, 2, 0),
|
||||
]
|
||||
|
||||
|
||||
def convert_attn_keys(key):
|
||||
@@ -126,7 +146,15 @@ def scale_proportion(obj_box, H, W):
|
||||
|
||||
# Adapted from the parent class `AttnProcessor2_0`
|
||||
class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
def __init__(self, attn_processor_key, hidden_size, cross_attention_dim, hook=None, fast_attn=True, enabled=True):
|
||||
def __init__(
|
||||
self,
|
||||
attn_processor_key,
|
||||
hidden_size,
|
||||
cross_attention_dim,
|
||||
hook=None,
|
||||
fast_attn=True,
|
||||
enabled=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn_processor_key = attn_processor_key
|
||||
self.hidden_size = hidden_size
|
||||
@@ -165,15 +193,16 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, scale=scale)
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, scale=scale)
|
||||
value = attn.to_v(encoder_hidden_states, scale=scale)
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
@@ -186,7 +215,13 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
|
||||
if self.hook is not None and self.enabled:
|
||||
# Call the hook with query, key, value, and attention maps
|
||||
self.hook(self.attn_processor_key, query_batch_dim, key_batch_dim, value_batch_dim, attention_probs)
|
||||
self.hook(
|
||||
self.attn_processor_key,
|
||||
query_batch_dim,
|
||||
key_batch_dim,
|
||||
value_batch_dim,
|
||||
attention_probs,
|
||||
)
|
||||
|
||||
if self.fast_attn:
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
@@ -202,7 +237,12 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -211,7 +251,7 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -226,7 +266,9 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
class LLMGroundedDiffusionPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf.
|
||||
|
||||
@@ -257,6 +299,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
Whether a safety checker is needed for this pipeline.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
objects_text = "Objects: "
|
||||
bg_prompt_text = "Background prompt: "
|
||||
bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
|
||||
@@ -272,12 +319,91 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
)
|
||||
# This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Initialize the attention hooks for LLM-grounded Diffusion
|
||||
self.register_attn_hooks(unet)
|
||||
self._saved_attn = None
|
||||
|
||||
@@ -464,7 +590,14 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
return token_map
|
||||
|
||||
def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_found=False, verbose=False):
|
||||
def get_phrase_indices(
|
||||
self,
|
||||
prompt,
|
||||
phrases,
|
||||
token_map=None,
|
||||
add_suffix_if_not_found=False,
|
||||
verbose=False,
|
||||
):
|
||||
for obj in phrases:
|
||||
# Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
|
||||
if obj not in prompt:
|
||||
@@ -485,7 +618,14 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
phrase_token_map_str = " ".join(phrase_token_map)
|
||||
|
||||
if verbose:
|
||||
logger.info("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases)
|
||||
logger.info(
|
||||
"Full str:",
|
||||
token_map_str,
|
||||
"Substr:",
|
||||
phrase_token_map_str,
|
||||
"Phrase:",
|
||||
phrases,
|
||||
)
|
||||
|
||||
# Count the number of token before substr
|
||||
# The substring comes with a trailing space that needs to be removed by minus one in the index.
|
||||
@@ -552,7 +692,15 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
return loss
|
||||
|
||||
def compute_ca_loss(self, saved_attn, bboxes, phrase_indices, guidance_attn_keys, verbose=False, **kwargs):
|
||||
def compute_ca_loss(
|
||||
self,
|
||||
saved_attn,
|
||||
bboxes,
|
||||
phrase_indices,
|
||||
guidance_attn_keys,
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss.
|
||||
`AttnProcessor` will put attention maps into the `save_attn_to_dict`.
|
||||
@@ -605,6 +753,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -662,6 +811,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -724,9 +874,10 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
phrase_indices = []
|
||||
prompt_parsed = []
|
||||
for prompt_item in prompt:
|
||||
phrase_indices_parsed_item, prompt_parsed_item = self.get_phrase_indices(
|
||||
prompt_item, add_suffix_if_not_found=True
|
||||
)
|
||||
(
|
||||
phrase_indices_parsed_item,
|
||||
prompt_parsed_item,
|
||||
) = self.get_phrase_indices(prompt_item, add_suffix_if_not_found=True)
|
||||
phrase_indices.append(phrase_indices_parsed_item)
|
||||
prompt_parsed.append(prompt_parsed_item)
|
||||
prompt = prompt_parsed
|
||||
@@ -759,6 +910,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
@@ -801,7 +957,10 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
if n_objs:
|
||||
cond_boxes[:n_objs] = torch.tensor(boxes)
|
||||
text_embeddings = torch.zeros(
|
||||
max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
|
||||
max_objs,
|
||||
self.unet.config.cross_attention_dim,
|
||||
device=device,
|
||||
dtype=self.text_encoder.dtype,
|
||||
)
|
||||
if n_objs:
|
||||
text_embeddings[:n_objs] = _text_embeddings
|
||||
@@ -833,6 +992,9 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
|
||||
loss_attn = torch.tensor(10000.0)
|
||||
|
||||
# 7. Denoising loop
|
||||
@@ -869,6 +1031,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
@@ -1013,3 +1176,438 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
self.enable_attn_hook(enabled=False)
|
||||
|
||||
return latents, loss
|
||||
|
||||
# Below are methods copied from StableDiffusionPipeline
|
||||
# The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
||||
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
|
||||
Args:
|
||||
timesteps (`torch.Tensor`):
|
||||
generate embedding vectors at these timesteps
|
||||
embedding_dim (`int`, *optional*, defaults to 512):
|
||||
dimension of the embeddings to generate
|
||||
dtype:
|
||||
data type of the generated embeddings
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
||||
"""
|
||||
assert len(w.shape) == 1
|
||||
w = w * 1000.0
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
||||
emb = w.to(dtype)[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1))
|
||||
assert emb.shape == (w.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,589 @@
|
||||
import math
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as FF
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import USE_PEFT_BACKEND
|
||||
|
||||
|
||||
try:
|
||||
from compel import Compel
|
||||
except ImportError:
|
||||
Compel = None
|
||||
|
||||
KCOMM = "ADDCOMM"
|
||||
KBRK = "BREAK"
|
||||
|
||||
|
||||
class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Args for Regional Prompting Pipeline:
|
||||
rp_args:dict
|
||||
Required
|
||||
rp_args["mode"]: cols, rows, prompt, prompt-ex
|
||||
for cols, rows mode
|
||||
rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
|
||||
for prompt, prompt-ex mode
|
||||
rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
|
||||
|
||||
Optional
|
||||
rp_args["save_mask"]: True/False (save masks in prompt mode)
|
||||
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
)
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: str = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
rp_args: Dict[str, str] = None,
|
||||
):
|
||||
active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721
|
||||
if negative_prompt is None:
|
||||
negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721
|
||||
|
||||
device = self._execution_device
|
||||
regions = 0
|
||||
|
||||
self.power = int(rp_args["power"]) if "power" in rp_args else 1
|
||||
|
||||
prompts = prompt if type(prompt) == list else [prompt] # noqa: E721
|
||||
n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721
|
||||
self.batch = batch = num_images_per_prompt * len(prompts)
|
||||
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
|
||||
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
|
||||
|
||||
cn = len(all_prompts_cn) == len(all_n_prompts_cn)
|
||||
|
||||
if Compel:
|
||||
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
|
||||
|
||||
def getcompelembs(prps):
|
||||
embl = []
|
||||
for prp in prps:
|
||||
embl.append(compel.build_conditioning_tensor(prp))
|
||||
return torch.cat(embl)
|
||||
|
||||
conds = getcompelembs(all_prompts_cn)
|
||||
unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
|
||||
embs = getcompelembs(prompts)
|
||||
n_embs = getcompelembs(n_prompts)
|
||||
prompt = negative_prompt = None
|
||||
else:
|
||||
conds = self.encode_prompt(prompts, device, 1, True)[0]
|
||||
unconds = (
|
||||
self.encode_prompt(n_prompts, device, 1, True)[0]
|
||||
if cn
|
||||
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
|
||||
)
|
||||
embs = n_embs = None
|
||||
|
||||
if not active:
|
||||
pcallback = None
|
||||
mode = None
|
||||
else:
|
||||
if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
|
||||
mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
|
||||
ocells, icells, regions = make_cells(rp_args["div"])
|
||||
|
||||
elif "PRO" in rp_args["mode"].upper():
|
||||
regions = len(all_prompts_p[0])
|
||||
mode = "PROMPT"
|
||||
reset_attnmaps(self)
|
||||
self.ex = "EX" in rp_args["mode"].upper()
|
||||
self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
|
||||
thresholds = [float(x) for x in rp_args["th"].split(",")]
|
||||
|
||||
orig_hw = (height, width)
|
||||
revers = True
|
||||
|
||||
def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor, selfs=None):
|
||||
if "PRO" in mode: # in Prompt mode, make masks from sum of attension maps
|
||||
self.step = step
|
||||
|
||||
if len(self.attnmaps_sizes) > 3:
|
||||
self.history[step] = self.attnmaps.copy()
|
||||
for hw in self.attnmaps_sizes:
|
||||
allmasks = []
|
||||
basemasks = [None] * batch
|
||||
for tt, th in zip(target_tokens, thresholds):
|
||||
for b in range(batch):
|
||||
key = f"{tt}-{b}"
|
||||
_, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
|
||||
mask = mask.unsqueeze(0).unsqueeze(-1)
|
||||
if self.ex:
|
||||
allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
|
||||
allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
|
||||
allmasks.append(mask)
|
||||
basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
|
||||
basemasks = [1 - mask for mask in basemasks]
|
||||
basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
|
||||
allmasks = basemasks + allmasks
|
||||
|
||||
self.attnmasks[hw] = torch.cat(allmasks)
|
||||
self.maskready = True
|
||||
return latents
|
||||
|
||||
def hook_forward(module):
|
||||
# diffusers==0.23.2
|
||||
def forward(
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
attn = module
|
||||
xshape = hidden_states.shape
|
||||
self.hw = (h, w) = split_dims(xshape[1], *orig_hw)
|
||||
|
||||
if revers:
|
||||
nx, px = hidden_states.chunk(2)
|
||||
else:
|
||||
px, nx = hidden_states.chunk(2)
|
||||
|
||||
if cn:
|
||||
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
|
||||
encoder_hidden_states = torch.cat([conds] + [unconds])
|
||||
else:
|
||||
hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
|
||||
encoder_hidden_states = torch.cat([conds] + [unconds])
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = scaled_dot_product_attention(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
getattn="PRO" in mode,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
#### Regional Prompting Col/Row mode
|
||||
if any(x in mode for x in ["COL", "ROW"]):
|
||||
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
|
||||
center = reshaped.shape[0] // 2
|
||||
px = reshaped[0:center] if cn else reshaped[0:-batch]
|
||||
nx = reshaped[center:] if cn else reshaped[-batch:]
|
||||
outs = [px, nx] if cn else [px]
|
||||
for out in outs:
|
||||
c = 0
|
||||
for i, ocell in enumerate(ocells):
|
||||
for icell in icells[i]:
|
||||
if "ROW" in mode:
|
||||
out[
|
||||
0:batch,
|
||||
int(h * ocell[0]) : int(h * ocell[1]),
|
||||
int(w * icell[0]) : int(w * icell[1]),
|
||||
:,
|
||||
] = out[
|
||||
c * batch : (c + 1) * batch,
|
||||
int(h * ocell[0]) : int(h * ocell[1]),
|
||||
int(w * icell[0]) : int(w * icell[1]),
|
||||
:,
|
||||
]
|
||||
else:
|
||||
out[
|
||||
0:batch,
|
||||
int(h * icell[0]) : int(h * icell[1]),
|
||||
int(w * ocell[0]) : int(w * ocell[1]),
|
||||
:,
|
||||
] = out[
|
||||
c * batch : (c + 1) * batch,
|
||||
int(h * icell[0]) : int(h * icell[1]),
|
||||
int(w * ocell[0]) : int(w * ocell[1]),
|
||||
:,
|
||||
]
|
||||
c += 1
|
||||
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
|
||||
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
|
||||
hidden_states = hidden_states.reshape(xshape)
|
||||
|
||||
#### Regional Prompting Prompt mode
|
||||
elif "PRO" in mode:
|
||||
center = reshaped.shape[0] // 2
|
||||
px = reshaped[0:center] if cn else reshaped[0:-batch]
|
||||
nx = reshaped[center:] if cn else reshaped[-batch:]
|
||||
|
||||
if (h, w) in self.attnmasks and self.maskready:
|
||||
|
||||
def mask(input):
|
||||
out = torch.multiply(input, self.attnmasks[(h, w)])
|
||||
for b in range(batch):
|
||||
for r in range(1, regions):
|
||||
out[b] = out[b] + out[r * batch + b]
|
||||
return out
|
||||
|
||||
px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
|
||||
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
|
||||
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
def hook_forwards(root_module: torch.nn.Module):
|
||||
for name, module in root_module.named_modules():
|
||||
if "attn2" in name and module.__class__.__name__ == "Attention":
|
||||
module.forward = hook_forward(module)
|
||||
|
||||
hook_forwards(self.unet)
|
||||
|
||||
output = StableDiffusionPipeline(**self.components)(
|
||||
prompt=prompt,
|
||||
prompt_embeds=embs,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_embeds=n_embs,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback_on_step_end=pcallback,
|
||||
)
|
||||
|
||||
if "save_mask" in rp_args:
|
||||
save_mask = rp_args["save_mask"]
|
||||
else:
|
||||
save_mask = False
|
||||
|
||||
if mode == "PROMPT" and save_mask:
|
||||
saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
### Make prompt list for each regions
|
||||
def promptsmaker(prompts, batch):
|
||||
out_p = []
|
||||
plen = len(prompts)
|
||||
for prompt in prompts:
|
||||
add = ""
|
||||
if KCOMM in prompt:
|
||||
add, prompt = prompt.split(KCOMM)
|
||||
add = add + " "
|
||||
prompts = prompt.split(KBRK)
|
||||
out_p.append([add + p for p in prompts])
|
||||
out = [None] * batch * len(out_p[0]) * len(out_p)
|
||||
for p, prs in enumerate(out_p): # inputs prompts
|
||||
for r, pr in enumerate(prs): # prompts for regions
|
||||
start = (p + r * plen) * batch
|
||||
out[start : start + batch] = [pr] * batch # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
|
||||
return out, out_p
|
||||
|
||||
|
||||
### make regions from ratios
|
||||
### ";" makes outercells, "," makes inner cells
|
||||
def make_cells(ratios):
|
||||
if ";" not in ratios and "," in ratios:
|
||||
ratios = ratios.replace(",", ";")
|
||||
ratios = ratios.split(";")
|
||||
ratios = [inratios.split(",") for inratios in ratios]
|
||||
|
||||
icells = []
|
||||
ocells = []
|
||||
|
||||
def startend(cells, array):
|
||||
current_start = 0
|
||||
array = [float(x) for x in array]
|
||||
for value in array:
|
||||
end = current_start + (value / sum(array))
|
||||
cells.append([current_start, end])
|
||||
current_start = end
|
||||
|
||||
startend(ocells, [r[0] for r in ratios])
|
||||
|
||||
for inratios in ratios:
|
||||
if 2 > len(inratios):
|
||||
icells.append([[0, 1]])
|
||||
else:
|
||||
add = []
|
||||
startend(add, inratios[1:])
|
||||
icells.append(add)
|
||||
|
||||
return ocells, icells, sum(len(cell) for cell in icells)
|
||||
|
||||
|
||||
def make_emblist(self, prompts):
|
||||
with torch.no_grad():
|
||||
tokens = self.tokenizer(
|
||||
prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
|
||||
).input_ids.to(self.device)
|
||||
embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
|
||||
return embs
|
||||
|
||||
|
||||
def split_dims(xs, height, width):
|
||||
xs = xs
|
||||
|
||||
def repeat_div(x, y):
|
||||
while y > 0:
|
||||
x = math.ceil(x / 2)
|
||||
y = y - 1
|
||||
return x
|
||||
|
||||
scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
|
||||
dsh = repeat_div(height, scale)
|
||||
dsw = repeat_div(width, scale)
|
||||
return dsh, dsw
|
||||
|
||||
|
||||
##### for prompt mode
|
||||
def get_attn_maps(self, attn):
|
||||
height, width = self.hw
|
||||
target_tokens = self.target_tokens
|
||||
if (height, width) not in self.attnmaps_sizes:
|
||||
self.attnmaps_sizes.append((height, width))
|
||||
|
||||
for b in range(self.batch):
|
||||
for t in target_tokens:
|
||||
power = self.power
|
||||
add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
|
||||
add = torch.sum(add, dim=2)
|
||||
key = f"{t}-{b}"
|
||||
if key not in self.attnmaps:
|
||||
self.attnmaps[key] = add
|
||||
else:
|
||||
if self.attnmaps[key].shape[1] != add.shape[1]:
|
||||
add = add.view(8, height, width)
|
||||
add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
|
||||
add = add.reshape_as(self.attnmaps[key])
|
||||
|
||||
self.attnmaps[key] = self.attnmaps[key] + add
|
||||
|
||||
|
||||
def reset_attnmaps(self): # init parameters in every batch
|
||||
self.step = 0
|
||||
self.attnmaps = {} # maked from attention maps
|
||||
self.attnmaps_sizes = [] # height,width set of u-net blocks
|
||||
self.attnmasks = {} # maked from attnmaps for regions
|
||||
self.maskready = False
|
||||
self.history = {}
|
||||
|
||||
|
||||
def saveattnmaps(self, output, h, w, th, step, regions):
|
||||
masks = []
|
||||
for i, mask in enumerate(self.history[step].values()):
|
||||
img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
|
||||
if self.ex:
|
||||
masks = [x - mask for x in masks]
|
||||
masks.append(mask)
|
||||
if len(masks) == regions - 1:
|
||||
output.images.extend([FF.to_pil_image(mask) for mask in masks])
|
||||
masks = []
|
||||
else:
|
||||
output.images.append(img)
|
||||
|
||||
|
||||
def makepmask(
|
||||
self, mask, h, w, th, step
|
||||
): # make masks from attention cache return [for preview, for attention, for Latent]
|
||||
th = th - step * 0.005
|
||||
if 0.05 >= th:
|
||||
th = 0.05
|
||||
mask = torch.mean(mask, dim=0)
|
||||
mask = mask / mask.max().item()
|
||||
mask = torch.where(mask > th, 1, 0)
|
||||
mask = mask.float()
|
||||
mask = mask.view(1, *self.attnmaps_sizes[0])
|
||||
img = FF.to_pil_image(mask)
|
||||
img = img.resize((w, h))
|
||||
mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
|
||||
lmask = mask
|
||||
mask = mask.reshape(h * w)
|
||||
mask = torch.where(mask > 0.1, 1, 0)
|
||||
return img, mask, lmask
|
||||
|
||||
|
||||
def tokendealer(self, all_prompts):
|
||||
for prompts in all_prompts:
|
||||
targets = [p.split(",")[-1] for p in prompts[1:]]
|
||||
tt = []
|
||||
|
||||
for target in targets:
|
||||
ptokens = (
|
||||
self.tokenizer(
|
||||
prompts,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
)[0]
|
||||
ttokens = (
|
||||
self.tokenizer(
|
||||
target,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
)[0]
|
||||
|
||||
tlist = []
|
||||
|
||||
for t in range(ttokens.shape[0] - 2):
|
||||
for p in range(ptokens.shape[0]):
|
||||
if ttokens[t + 1] == ptokens[p]:
|
||||
tlist.append(p)
|
||||
if tlist != []:
|
||||
tt.append(tlist)
|
||||
|
||||
return tt
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
|
||||
) -> torch.Tensor:
|
||||
# Efficient implementation equivalent to the following:
|
||||
L, S = query.size(-2), key.size(-2)
|
||||
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
||||
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
|
||||
if is_causal:
|
||||
assert attn_mask is None
|
||||
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
||||
else:
|
||||
attn_bias += attn_mask
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight += attn_bias
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
if getattn:
|
||||
get_attn_maps(self, attn_weight)
|
||||
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
|
||||
return attn_weight @ value
|
||||
@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
|
||||
save_engine,
|
||||
)
|
||||
from polygraphy.backend.trt import util as trt_util
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
@@ -709,6 +709,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae", "vae_encoder"],
|
||||
image_height: int = 512,
|
||||
@@ -724,7 +725,15 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
timing_cache: str = "timing_cache",
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
|
||||
@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
|
||||
save_engine,
|
||||
)
|
||||
from polygraphy.backend.trt import util as trt_util
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
@@ -710,6 +710,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae", "vae_encoder"],
|
||||
image_height: int = 512,
|
||||
@@ -725,7 +726,15 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
timing_cache: str = "timing_cache",
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
|
||||
@@ -40,7 +40,7 @@ from polygraphy.backend.trt import (
|
||||
save_engine,
|
||||
)
|
||||
from polygraphy.backend.trt import util as trt_util
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
@@ -624,6 +624,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae"],
|
||||
image_height: int = 768,
|
||||
@@ -639,7 +640,15 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
timing_cache: str = "timing_cache",
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Latent Consistency Distillation Example:
|
||||
|
||||
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill stable-diffusion-v1.5 for less timestep inference.
|
||||
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with few timesteps.
|
||||
|
||||
## Full model distillation
|
||||
|
||||
@@ -24,7 +24,7 @@ Then cd in the example folder and run
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
@@ -46,12 +46,16 @@ write_basic_config()
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
|
||||
#### Example with LAION-A6+ dataset
|
||||
#### Example
|
||||
|
||||
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
|
||||
|
||||
```bash
|
||||
runwayml/stable-diffusion-v1-5
|
||||
PROGRAM="train_lcm_distill_sd_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_DIR \
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="path/to/saved/model"
|
||||
|
||||
accelerate launch train_lcm_distill_sd_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=512 \
|
||||
@@ -59,7 +63,7 @@ PROGRAM="train_lcm_distill_sd_wds.py \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
|
||||
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
|
||||
--validation_steps=200 \
|
||||
--checkpointing_steps=200 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=12 \
|
||||
@@ -69,19 +73,23 @@ PROGRAM="train_lcm_distill_sd_wds.py \
|
||||
--resume_from_checkpoint=latest \
|
||||
--report_to=wandb \
|
||||
--seed=453645634 \
|
||||
--push_to_hub \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## LCM-LoRA
|
||||
|
||||
Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
|
||||
|
||||
### Example with LAION-A6+ dataset
|
||||
|
||||
### Example
|
||||
|
||||
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
|
||||
|
||||
```bash
|
||||
runwayml/stable-diffusion-v1-5
|
||||
PROGRAM="train_lcm_distill_lora_sd_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_DIR \
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="path/to/saved/model"
|
||||
|
||||
accelerate launch train_lcm_distill_lora_sd_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=512 \
|
||||
@@ -90,7 +98,7 @@ PROGRAM="train_lcm_distill_lora_sd_wds.py \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
|
||||
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
|
||||
--validation_steps=200 \
|
||||
--checkpointing_steps=200 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=12 \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Latent Consistency Distillation Example:
|
||||
|
||||
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill SDXL for less timestep inference.
|
||||
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with few timesteps.
|
||||
|
||||
## Full model distillation
|
||||
|
||||
@@ -24,7 +24,7 @@ Then cd in the example folder and run
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
@@ -46,12 +46,16 @@ write_basic_config()
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
|
||||
#### Example with LAION-A6+ dataset
|
||||
#### Example
|
||||
|
||||
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
PROGRAM="train_lcm_distill_sdxl_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_DIR \
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export OUTPUT_DIR="path/to/saved/model"
|
||||
|
||||
accelerate launch train_lcm_distill_sdxl_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_NAME \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision=fp16 \
|
||||
@@ -60,7 +64,7 @@ PROGRAM="train_lcm_distill_sdxl_wds.py \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
|
||||
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
|
||||
--validation_steps=200 \
|
||||
--checkpointing_steps=200 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=12 \
|
||||
@@ -77,11 +81,15 @@ PROGRAM="train_lcm_distill_sdxl_wds.py \
|
||||
|
||||
Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
|
||||
|
||||
### Example with LAION-A6+ dataset
|
||||
|
||||
### Example
|
||||
|
||||
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export OUTPUT_DIR="path/to/saved/model"
|
||||
|
||||
accelerate launch train_lcm_distill_lora_sdxl_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_DIR \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
@@ -92,7 +100,7 @@ PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
|
||||
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
|
||||
--validation_steps=200 \
|
||||
--checkpointing_steps=200 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=12 \
|
||||
|
||||
@@ -71,7 +71,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1123,7 +1123,7 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
image, text, _, _ = batch
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
encoded_text = compute_embeddings_fn(text)
|
||||
|
||||
@@ -68,11 +68,16 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
MAX_SEQ_LENGTH = 77
|
||||
|
||||
# Adjust for your dataset
|
||||
WDS_JSON_WIDTH = "width" # original_width for LAION
|
||||
WDS_JSON_HEIGHT = "height" # original_height for LAION
|
||||
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -146,10 +151,10 @@ class WebdatasetFilter:
|
||||
try:
|
||||
if "json" in x:
|
||||
x_json = json.loads(x["json"])
|
||||
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
|
||||
"original_height", 0
|
||||
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
|
||||
WDS_JSON_HEIGHT, 0
|
||||
) >= self.min_size
|
||||
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
|
||||
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
|
||||
return filter_size and filter_watermark
|
||||
else:
|
||||
return False
|
||||
@@ -180,7 +185,7 @@ class Text2ImageDataset:
|
||||
if use_fix_crop_and_size:
|
||||
return (resolution, resolution)
|
||||
else:
|
||||
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
|
||||
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
|
||||
|
||||
def transform(example):
|
||||
# resize image
|
||||
@@ -212,7 +217,7 @@ class Text2ImageDataset:
|
||||
pipeline = [
|
||||
wds.ResampledShards(train_shards_path_or_url),
|
||||
tarfile_to_samples_nothrow,
|
||||
wds.select(WebdatasetFilter(min_size=960)),
|
||||
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
|
||||
wds.shuffle(shuffle_buffer_size),
|
||||
*processing_pipeline,
|
||||
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1106,7 +1106,7 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
image, text, _, _ = batch
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
encoded_text = compute_embeddings_fn(text)
|
||||
|
||||
@@ -67,11 +67,16 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
MAX_SEQ_LENGTH = 77
|
||||
|
||||
# Adjust for your dataset
|
||||
WDS_JSON_WIDTH = "width" # original_width for LAION
|
||||
WDS_JSON_HEIGHT = "height" # original_height for LAION
|
||||
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -128,10 +133,10 @@ class WebdatasetFilter:
|
||||
try:
|
||||
if "json" in x:
|
||||
x_json = json.loads(x["json"])
|
||||
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
|
||||
"original_height", 0
|
||||
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
|
||||
WDS_JSON_HEIGHT, 0
|
||||
) >= self.min_size
|
||||
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
|
||||
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
|
||||
return filter_size and filter_watermark
|
||||
else:
|
||||
return False
|
||||
@@ -162,7 +167,7 @@ class Text2ImageDataset:
|
||||
if use_fix_crop_and_size:
|
||||
return (resolution, resolution)
|
||||
else:
|
||||
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
|
||||
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
|
||||
|
||||
def transform(example):
|
||||
# resize image
|
||||
@@ -194,7 +199,7 @@ class Text2ImageDataset:
|
||||
pipeline = [
|
||||
wds.ResampledShards(train_shards_path_or_url),
|
||||
tarfile_to_samples_nothrow,
|
||||
wds.select(WebdatasetFilter(min_size=960)),
|
||||
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
|
||||
wds.shuffle(shuffle_buffer_size),
|
||||
*processing_pipeline,
|
||||
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -65,7 +65,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -118,9 +118,10 @@ _deps = [
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"python>=3.8.0",
|
||||
"ruff>=0.1.5,<=0.2",
|
||||
"ruff==0.1.5",
|
||||
"safetensors>=0.3.1",
|
||||
"sentencepiece>=0.1.91,!=0.1.92",
|
||||
"GitPython<3.1.19",
|
||||
"scipy",
|
||||
"onnx",
|
||||
"regex!=2019.12.17",
|
||||
@@ -206,6 +207,7 @@ extras["docs"] = deps_list("hf-doc-builder")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
|
||||
extras["test"] = deps_list(
|
||||
"compel",
|
||||
"GitPython",
|
||||
"datasets",
|
||||
"Jinja2",
|
||||
"invisible-watermark",
|
||||
@@ -249,13 +251,13 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.24.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.25.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords="deep learning diffusion jax pytorch stable diffusion audioldm",
|
||||
license="Apache",
|
||||
author="The HuggingFace team",
|
||||
license="Apache 2.0 License",
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)",
|
||||
author_email="patrick@huggingface.co",
|
||||
url="https://github.com/huggingface/diffusers",
|
||||
package_dir={"": "src"},
|
||||
@@ -279,24 +281,3 @@ setup(
|
||||
+ [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)],
|
||||
cmdclass={"deps_table_update": DepsTableUpdateCommand},
|
||||
)
|
||||
|
||||
|
||||
# Release checklist
|
||||
# 1. Change the version in __init__.py and setup.py.
|
||||
# 2. Commit these changes with the message: "Release: Release"
|
||||
# 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for PyPI'"
|
||||
# Push the tag to git: git push --tags origin main
|
||||
# 4. Run the following commands in the top-level directory:
|
||||
# python setup.py bdist_wheel
|
||||
# python setup.py sdist
|
||||
# 5. Upload the package to the PyPI test server first:
|
||||
# twine upload dist/* -r pypitest
|
||||
# twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
|
||||
# 6. Check that you can install it in a virtualenv by running:
|
||||
# pip install -i https://testpypi.python.org/pypi diffusers
|
||||
# diffusers env
|
||||
# diffusers test
|
||||
# 7. Upload the final version to the actual PyPI:
|
||||
# twine upload dist/* -r pypi
|
||||
# 8. Add release notes to the tag in GitHub once everything is looking hunky-dory.
|
||||
# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to main.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.24.0.dev0"
|
||||
__version__ = "0.25.0.dev0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -30,9 +30,10 @@ deps = {
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
"python": "python>=3.8.0",
|
||||
"ruff": "ruff>=0.1.5,<=0.2",
|
||||
"ruff": "ruff==0.1.5",
|
||||
"safetensors": "safetensors>=0.3.1",
|
||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
"scipy": "scipy",
|
||||
"onnx": "onnx",
|
||||
"regex": "regex!=2019.12.17",
|
||||
|
||||
@@ -113,7 +113,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
|
||||
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
|
||||
|
||||
# TODO: verify deprecation of this kwarg
|
||||
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
|
||||
x = self.scheduler.step(prev_x, i, x)["prev_sample"]
|
||||
|
||||
# apply conditions to the trajectory (set the initial state)
|
||||
x = self.reset_x0(x, conditions, self.action_dim)
|
||||
|
||||
@@ -391,6 +391,10 @@ class LoraLoaderMixin:
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
if all(key.startswith("unet.unet") for key in keys):
|
||||
deprecation_message = "Keys starting with 'unet.unet' are deprecated."
|
||||
deprecate("unet.unet keys", "0.27", deprecation_message)
|
||||
|
||||
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
@@ -407,8 +411,9 @@ class LoraLoaderMixin:
|
||||
else:
|
||||
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
|
||||
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
logger.warn(warn_message)
|
||||
if not USE_PEFT_BACKEND:
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
logger.warn(warn_message)
|
||||
|
||||
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
@@ -675,8 +680,7 @@ class LoraLoaderMixin:
|
||||
|
||||
@classmethod
|
||||
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
@@ -704,8 +708,7 @@ class LoraLoaderMixin:
|
||||
r"""
|
||||
Monkey-patches the forward passes of attention modules of the text encoder.
|
||||
"""
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
|
||||
deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
|
||||
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
|
||||
@@ -802,29 +805,21 @@ class LoraLoaderMixin:
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
# Create a flat dictionary.
|
||||
state_dict = {}
|
||||
|
||||
# Populate the dictionary.
|
||||
if unet_lora_layers is not None:
|
||||
weights = (
|
||||
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
|
||||
)
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
|
||||
state_dict.update(unet_lora_state_dict)
|
||||
if not (unet_lora_layers or text_encoder_lora_layers):
|
||||
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
|
||||
|
||||
if text_encoder_lora_layers is not None:
|
||||
weights = (
|
||||
text_encoder_lora_layers.state_dict()
|
||||
if isinstance(text_encoder_lora_layers, torch.nn.Module)
|
||||
else text_encoder_lora_layers
|
||||
)
|
||||
if unet_lora_layers:
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
|
||||
text_encoder_lora_state_dict = {
|
||||
f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
|
||||
}
|
||||
state_dict.update(text_encoder_lora_state_dict)
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
@@ -948,8 +943,7 @@ class LoraLoaderMixin:
|
||||
module.merge()
|
||||
|
||||
else:
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
|
||||
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
@@ -1006,8 +1000,7 @@ class LoraLoaderMixin:
|
||||
module.unmerge()
|
||||
|
||||
else:
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
|
||||
deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
|
||||
@@ -282,7 +282,7 @@ class FromSingleFileMixin:
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(torch_dtype=torch_dtype)
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict, defaultdict
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..models.embeddings import ImageProjection
|
||||
from ..models.embeddings import ImageProjection, Resampler
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
@@ -672,6 +672,17 @@ class UNet2DConditionLoadersMixin:
|
||||
IPAdapterAttnProcessor2_0,
|
||||
)
|
||||
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
|
||||
|
||||
# Set encoder_hid_proj after loading ip_adapter weights,
|
||||
# because `Resampler` also has `attn_processors`.
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
key_id = 1
|
||||
@@ -695,7 +706,10 @@ class UNet2DConditionLoadersMixin:
|
||||
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
|
||||
)
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
num_tokens=num_image_text_embeds,
|
||||
).to(dtype=self.dtype, device=self.device)
|
||||
|
||||
value_dict = {}
|
||||
@@ -708,26 +722,76 @@ class UNet2DConditionLoadersMixin:
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
# create image projection layers.
|
||||
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter
|
||||
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
|
||||
)
|
||||
image_projection.to(dtype=self.dtype, device=self.device)
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
image_projection.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
# load image projection layer weights
|
||||
image_proj_state_dict = {}
|
||||
image_proj_state_dict.update(
|
||||
{
|
||||
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
|
||||
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
|
||||
"norm.weight": state_dict["image_proj"]["norm.weight"],
|
||||
"norm.bias": state_dict["image_proj"]["norm.bias"],
|
||||
}
|
||||
)
|
||||
# load image projection layer weights
|
||||
image_proj_state_dict = {}
|
||||
image_proj_state_dict.update(
|
||||
{
|
||||
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
|
||||
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
|
||||
"norm.weight": state_dict["image_proj"]["norm.weight"],
|
||||
"norm.bias": state_dict["image_proj"]["norm.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
image_projection.load_state_dict(image_proj_state_dict)
|
||||
image_projection.load_state_dict(image_proj_state_dict)
|
||||
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
|
||||
output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
|
||||
hidden_dims = state_dict["image_proj"]["latents"].shape[2]
|
||||
heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
|
||||
|
||||
image_projection = Resampler(
|
||||
embed_dims=embed_dims,
|
||||
output_dims=output_dims,
|
||||
hidden_dims=hidden_dims,
|
||||
heads=heads,
|
||||
num_queries=num_image_text_embeds,
|
||||
)
|
||||
|
||||
image_proj_state_dict = state_dict["image_proj"]
|
||||
|
||||
new_sd = OrderedDict()
|
||||
for k, v in image_proj_state_dict.items():
|
||||
if "0.to" in k:
|
||||
k = k.replace("0.to", "2.to")
|
||||
elif "1.0.weight" in k:
|
||||
k = k.replace("1.0.weight", "3.0.weight")
|
||||
elif "1.0.bias" in k:
|
||||
k = k.replace("1.0.bias", "3.0.bias")
|
||||
elif "1.1.weight" in k:
|
||||
k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
|
||||
elif "1.3.weight" in k:
|
||||
k = k.replace("1.3.weight", "3.1.net.2.weight")
|
||||
|
||||
if "norm1" in k:
|
||||
new_sd[k.replace("0.norm1", "0")] = v
|
||||
elif "norm2" in k:
|
||||
new_sd[k.replace("0.norm2", "1")] = v
|
||||
elif "to_kv" in k:
|
||||
v_chunk = v.chunk(2, dim=0)
|
||||
new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
|
||||
new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
|
||||
elif "to_out" in k:
|
||||
new_sd[k.replace("to_out", "to_out.0")] = v
|
||||
else:
|
||||
new_sd[k] = v
|
||||
|
||||
image_projection.load_state_dict(new_sd)
|
||||
del image_proj_state_dict
|
||||
|
||||
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
@@ -34,6 +34,7 @@ if is_torch_available():
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformer_2d"] = ["Transformer2DModel"]
|
||||
@@ -42,7 +43,7 @@ if is_torch_available():
|
||||
_import_structure["unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
|
||||
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["vq_model"] = ["VQModel"]
|
||||
@@ -63,6 +64,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
from .controlnet import ControlNetModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
@@ -72,7 +74,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_condition import UNet3DConditionModel
|
||||
from .unet_kandi3 import Kandinsky3UNet
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .vq_model import VQModel
|
||||
|
||||
@@ -55,11 +55,12 @@ class GELU(nn.Module):
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out)
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
self.approximate = approximate
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
@@ -81,13 +82,14 @@ class GEGLU(nn.Module):
|
||||
Parameters:
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
super().__init__()
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
|
||||
self.proj = linear_cls(dim_in, dim_out * 2)
|
||||
self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
if gate.device.type != "mps":
|
||||
@@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module):
|
||||
Parameters:
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out)
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
|
||||
@@ -501,6 +501,7 @@ class FeedForward(nn.Module):
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -511,6 +512,7 @@ class FeedForward(nn.Module):
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
final_dropout: bool = False,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
@@ -518,13 +520,13 @@ class FeedForward(nn.Module):
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim)
|
||||
act_fn = GELU(dim, inner_dim, bias=bias)
|
||||
if activation_fn == "gelu-approximate":
|
||||
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
||||
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
|
||||
elif activation_fn == "geglu":
|
||||
act_fn = GEGLU(dim, inner_dim)
|
||||
act_fn = GEGLU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "geglu-approximate":
|
||||
act_fn = ApproximateGELU(dim, inner_dim)
|
||||
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
@@ -532,7 +534,7 @@ class FeedForward(nn.Module):
|
||||
# project dropout
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
self.net.append(linear_cls(inner_dim, dim_out))
|
||||
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
|
||||
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
||||
if final_dropout:
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import einsum, nn
|
||||
from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND, deprecate, logging
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
@@ -109,15 +109,19 @@ class Attention(nn.Module):
|
||||
residual_connection: bool = False,
|
||||
_from_deprecated_attn_block: bool = False,
|
||||
processor: Optional["AttnProcessor"] = None,
|
||||
out_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.inner_dim = dim_head * heads
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
self.residual_connection = residual_connection
|
||||
self.dropout = dropout
|
||||
self.fused_projections = False
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
|
||||
# we make use of this private variable to know whether this class is loaded
|
||||
# with an deprecated state dict so that we can convert it on the fly
|
||||
@@ -126,7 +130,7 @@ class Attention(nn.Module):
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = heads
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
@@ -178,6 +182,7 @@ class Attention(nn.Module):
|
||||
else:
|
||||
linear_cls = LoRACompatibleLinear
|
||||
|
||||
self.linear_cls = linear_cls
|
||||
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if not self.only_cross_attention:
|
||||
@@ -193,7 +198,7 @@ class Attention(nn.Module):
|
||||
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
|
||||
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
# set attention processor
|
||||
@@ -690,6 +695,32 @@ class Attention(nn.Module):
|
||||
|
||||
return encoder_hidden_states
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse_projections(self, fuse=True):
|
||||
is_cross_attention = self.cross_attention_dim != self.query_dim
|
||||
device = self.to_q.weight.data.device
|
||||
dtype = self.to_q.weight.data.dtype
|
||||
|
||||
if not is_cross_attention:
|
||||
# fetch weight matrices.
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
in_features = concatenated_weights.shape[1]
|
||||
out_features = concatenated_weights.shape[0]
|
||||
|
||||
# create a new single projection layer and copy over the weights.
|
||||
self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
|
||||
self.to_qkv.weight.copy_(concatenated_weights)
|
||||
|
||||
else:
|
||||
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
|
||||
in_features = concatenated_weights.shape[1]
|
||||
out_features = concatenated_weights.shape[0]
|
||||
|
||||
self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
|
||||
self.to_kv.weight.copy_(concatenated_weights)
|
||||
|
||||
self.fused_projections = fuse
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
r"""
|
||||
@@ -1182,9 +1213,6 @@ class AttnProcessor2_0:
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -1251,6 +1279,103 @@ class AttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is currently 🧪 experimental in nature and can change in future.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
if encoder_hidden_states is None:
|
||||
qkv = attn.to_qkv(hidden_states, *args)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
else:
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
kv = attn.to_kv(encoder_hidden_states, *args)
|
||||
split_size = kv.shape[-1] // 2
|
||||
key, value = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
||||
r"""
|
||||
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
|
||||
@@ -2219,44 +2344,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
|
||||
# this way torch.compile and co. will work as well
|
||||
class Kandi3AttnProcessor:
|
||||
r"""
|
||||
Default kandinsky3 proccesor for performing attention-related computations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _reshape(hid_states, h):
|
||||
b, n, f = hid_states.shape
|
||||
d = f // h
|
||||
return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
x,
|
||||
context,
|
||||
context_mask=None,
|
||||
):
|
||||
query = self._reshape(attn.to_q(x), h=attn.num_heads)
|
||||
key = self._reshape(attn.to_k(context), h=attn.num_heads)
|
||||
value = self._reshape(attn.to_v(context), h=attn.num_heads)
|
||||
|
||||
attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
|
||||
|
||||
if context_mask is not None:
|
||||
max_neg_value = -torch.finfo(attention_matrix.dtype).max
|
||||
context_mask = context_mask.unsqueeze(1).unsqueeze(1)
|
||||
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
|
||||
attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
|
||||
|
||||
out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
|
||||
out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
|
||||
out = attn.to_out[0](out)
|
||||
return out
|
||||
|
||||
|
||||
LORA_ATTENTION_PROCESSORS = (
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
@@ -2282,12 +2369,12 @@ CROSS_ATTENTION_PROCESSORS = (
|
||||
LoRAXFormersAttnProcessor,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
Kandi3AttnProcessor,
|
||||
)
|
||||
|
||||
AttentionProcessor = Union[
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
SlicedAttnProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
|
||||
@@ -22,6 +22,7 @@ from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
@@ -448,3 +449,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
@@ -20,6 +20,7 @@ from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .activations import get_activation
|
||||
from .attention_processor import Attention
|
||||
from .lora import LoRACompatibleLinear
|
||||
|
||||
|
||||
@@ -790,3 +791,91 @@ class CaptionProjection(nn.Module):
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Resampler(nn.Module):
|
||||
"""Resampler of IP-Adapter Plus.
|
||||
|
||||
Args:
|
||||
----
|
||||
embed_dims (int): The feature dimension. Defaults to 768.
|
||||
output_dims (int): The number of output channels, that is the same
|
||||
number of the channels in the
|
||||
`unet.config.cross_attention_dim`. Defaults to 1024.
|
||||
hidden_dims (int): The number of hidden channels. Defaults to 1280.
|
||||
depth (int): The number of blocks. Defaults to 8.
|
||||
dim_head (int): The number of head channels. Defaults to 64.
|
||||
heads (int): Parallel attention heads. Defaults to 16.
|
||||
num_queries (int): The number of queries. Defaults to 8.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dims: int = 768,
|
||||
output_dims: int = 1024,
|
||||
hidden_dims: int = 1280,
|
||||
depth: int = 4,
|
||||
dim_head: int = 64,
|
||||
heads: int = 16,
|
||||
num_queries: int = 8,
|
||||
ffn_ratio: float = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
from .attention import FeedForward # Lazy import to avoid circular import
|
||||
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
|
||||
|
||||
self.proj_in = nn.Linear(embed_dims, hidden_dims)
|
||||
|
||||
self.proj_out = nn.Linear(hidden_dims, output_dims)
|
||||
self.norm_out = nn.LayerNorm(output_dims)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
nn.LayerNorm(hidden_dims),
|
||||
nn.LayerNorm(hidden_dims),
|
||||
Attention(
|
||||
query_dim=hidden_dims,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
out_bias=False,
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.LayerNorm(hidden_dims),
|
||||
FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
----
|
||||
x (torch.Tensor): Input Tensor.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
torch.Tensor: Output Tensor.
|
||||
"""
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
for ln0, ln1, attn, ff in self.layers:
|
||||
residual = latents
|
||||
|
||||
encoder_hidden_states = ln0(x)
|
||||
latents = ln1(latents)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
|
||||
latents = attn(latents, encoder_hidden_states) + residual
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
@@ -25,6 +25,7 @@ from .activations import get_activation
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
@@ -794,6 +795,42 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
import math
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import AttentionProcessor, Kandi3AttnProcessor
|
||||
from .embeddings import TimestepEmbedding
|
||||
from .attention_processor import Attention, AttentionProcessor, AttnProcessor
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
|
||||
|
||||
@@ -22,36 +34,6 @@ class Kandinsky3UNetOutput(BaseOutput):
|
||||
sample: torch.FloatTensor = None
|
||||
|
||||
|
||||
# TODO(Yiyi): This class needs to be removed
|
||||
def set_default_item(condition, item_1, item_2=None):
|
||||
if condition:
|
||||
return item_1
|
||||
else:
|
||||
return item_2
|
||||
|
||||
|
||||
# TODO(Yiyi): This class needs to be removed
|
||||
def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}):
|
||||
if condition:
|
||||
return layer_1(*args_1, **kwargs_1)
|
||||
else:
|
||||
return layer_2(*args_2, **kwargs_2)
|
||||
|
||||
|
||||
# TODO(Yiyi): This class should be removed and be replaced by Timesteps
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x, type_tensor=None):
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
return torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
|
||||
|
||||
class Kandinsky3EncoderProj(nn.Module):
|
||||
def __init__(self, encoder_hid_dim, cross_attention_dim):
|
||||
super().__init__()
|
||||
@@ -87,9 +69,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
||||
|
||||
out_channels = in_channels
|
||||
init_channels = block_out_channels[0] // 2
|
||||
# TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same
|
||||
# self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
|
||||
self.time_proj = SinusoidalPosEmb(init_channels)
|
||||
self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
|
||||
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
init_channels,
|
||||
@@ -106,7 +86,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
||||
|
||||
hidden_dims = [init_channels] + list(block_out_channels)
|
||||
in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
|
||||
text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention]
|
||||
text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention]
|
||||
num_blocks = len(block_out_channels) * [layers_per_block]
|
||||
layer_params = [num_blocks, text_dims, add_self_attention]
|
||||
rev_layer_params = map(reversed, layer_params)
|
||||
@@ -118,7 +98,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
||||
zip(in_out_dims, *layer_params)
|
||||
):
|
||||
down_sample = level != (self.num_levels - 1)
|
||||
cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
|
||||
cat_dims.append(out_dim if level != (self.num_levels - 1) else 0)
|
||||
self.down_blocks.append(
|
||||
Kandinsky3DownSampleBlock(
|
||||
in_dim,
|
||||
@@ -223,18 +203,16 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
self.set_attn_processor(Kandi3AttnProcessor())
|
||||
self.set_attn_processor(AttnProcessor())
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
|
||||
# TODO(Yiyi): Clean up the following variables - these names should not be used
|
||||
# but instead only the ones that we pass to forward
|
||||
x = sample
|
||||
context_mask = encoder_attention_mask
|
||||
context = encoder_hidden_states
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
if not torch.is_tensor(timestep):
|
||||
dtype = torch.float32 if isinstance(timestep, float) else torch.int32
|
||||
@@ -244,33 +222,33 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = timestep.expand(sample.shape[0])
|
||||
time_embed_input = self.time_proj(timestep).to(x.dtype)
|
||||
time_embed_input = self.time_proj(timestep).to(sample.dtype)
|
||||
time_embed = self.time_embedding(time_embed_input)
|
||||
|
||||
context = self.encoder_hid_proj(context)
|
||||
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
||||
|
||||
if context is not None:
|
||||
time_embed = self.add_time_condition(time_embed, context, context_mask)
|
||||
if encoder_hidden_states is not None:
|
||||
time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
|
||||
|
||||
hidden_states = []
|
||||
x = self.conv_in(x)
|
||||
sample = self.conv_in(sample)
|
||||
for level, down_sample in enumerate(self.down_blocks):
|
||||
x = down_sample(x, time_embed, context, context_mask)
|
||||
sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
|
||||
if level != self.num_levels - 1:
|
||||
hidden_states.append(x)
|
||||
hidden_states.append(sample)
|
||||
|
||||
for level, up_sample in enumerate(self.up_blocks):
|
||||
if level != 0:
|
||||
x = torch.cat([x, hidden_states.pop()], dim=1)
|
||||
x = up_sample(x, time_embed, context, context_mask)
|
||||
sample = torch.cat([sample, hidden_states.pop()], dim=1)
|
||||
sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
|
||||
|
||||
x = self.conv_norm_out(x)
|
||||
x = self.conv_act_out(x)
|
||||
x = self.conv_out(x)
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act_out(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if not return_dict:
|
||||
return (x,)
|
||||
return Kandinsky3UNetOutput(sample=x)
|
||||
return (sample,)
|
||||
return Kandinsky3UNetOutput(sample=sample)
|
||||
|
||||
|
||||
class Kandinsky3UpSampleBlock(nn.Module):
|
||||
@@ -290,7 +268,7 @@ class Kandinsky3UpSampleBlock(nn.Module):
|
||||
self_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
|
||||
up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
|
||||
hidden_channels = (
|
||||
[(in_channels + cat_dim, in_channels)]
|
||||
+ [(in_channels, in_channels)] * (num_blocks - 2)
|
||||
@@ -303,27 +281,27 @@ class Kandinsky3UpSampleBlock(nn.Module):
|
||||
self.self_attention = self_attention
|
||||
self.context_dim = context_dim
|
||||
|
||||
attentions.append(
|
||||
set_default_layer(
|
||||
self_attention,
|
||||
Kandinsky3AttentionBlock,
|
||||
(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
|
||||
layer_2=nn.Identity,
|
||||
if self_attention:
|
||||
attentions.append(
|
||||
Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(nn.Identity())
|
||||
|
||||
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
|
||||
resnets_in.append(
|
||||
Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
|
||||
)
|
||||
attentions.append(
|
||||
set_default_layer(
|
||||
context_dim is not None,
|
||||
Kandinsky3AttentionBlock,
|
||||
(in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
|
||||
layer_2=nn.Identity,
|
||||
|
||||
if context_dim is not None:
|
||||
attentions.append(
|
||||
Kandinsky3AttentionBlock(
|
||||
in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(nn.Identity())
|
||||
|
||||
resnets_out.append(
|
||||
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
|
||||
)
|
||||
@@ -367,29 +345,29 @@ class Kandinsky3DownSampleBlock(nn.Module):
|
||||
self.self_attention = self_attention
|
||||
self.context_dim = context_dim
|
||||
|
||||
attentions.append(
|
||||
set_default_layer(
|
||||
self_attention,
|
||||
Kandinsky3AttentionBlock,
|
||||
(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
|
||||
layer_2=nn.Identity,
|
||||
if self_attention:
|
||||
attentions.append(
|
||||
Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(nn.Identity())
|
||||
|
||||
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]]
|
||||
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
|
||||
hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
|
||||
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
|
||||
resnets_in.append(
|
||||
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
|
||||
)
|
||||
attentions.append(
|
||||
set_default_layer(
|
||||
context_dim is not None,
|
||||
Kandinsky3AttentionBlock,
|
||||
(out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
|
||||
layer_2=nn.Identity,
|
||||
|
||||
if context_dim is not None:
|
||||
attentions.append(
|
||||
Kandinsky3AttentionBlock(
|
||||
out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(nn.Identity())
|
||||
|
||||
resnets_out.append(
|
||||
Kandinsky3ResNetBlock(
|
||||
out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
|
||||
@@ -431,68 +409,23 @@ class Kandinsky3ConditionalGroupNorm(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty
|
||||
# sure we can delete it and instead just pass an attention_mask
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
|
||||
super().__init__()
|
||||
assert out_channels % head_dim == 0
|
||||
self.num_heads = out_channels // head_dim
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
# to_q
|
||||
self.to_q = nn.Linear(in_channels, out_channels, bias=False)
|
||||
# to_k
|
||||
self.to_k = nn.Linear(context_dim, out_channels, bias=False)
|
||||
# to_v
|
||||
self.to_v = nn.Linear(context_dim, out_channels, bias=False)
|
||||
processor = Kandi3AttnProcessor()
|
||||
self.set_processor(processor)
|
||||
# to_out
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))
|
||||
|
||||
def set_processor(self, processor: "AttnProcessor"): # noqa: F821
|
||||
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
||||
# pop `processor` from `self._modules`
|
||||
if (
|
||||
hasattr(self, "processor")
|
||||
and isinstance(self.processor, torch.nn.Module)
|
||||
and not isinstance(processor, torch.nn.Module)
|
||||
):
|
||||
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
||||
self._modules.pop("processor")
|
||||
|
||||
self.processor = processor
|
||||
|
||||
def forward(self, x, context, context_mask=None, image_mask=None):
|
||||
return self.processor(
|
||||
self,
|
||||
x,
|
||||
context=context,
|
||||
context_mask=context_mask,
|
||||
)
|
||||
|
||||
|
||||
class Kandinsky3Block(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
|
||||
super().__init__()
|
||||
self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
|
||||
self.activation = nn.SiLU()
|
||||
self.up_sample = set_default_layer(
|
||||
up_resolution is not None and up_resolution,
|
||||
nn.ConvTranspose2d,
|
||||
(in_channels, in_channels),
|
||||
{"kernel_size": 2, "stride": 2},
|
||||
)
|
||||
if up_resolution is not None and up_resolution:
|
||||
self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
|
||||
else:
|
||||
self.up_sample = nn.Identity()
|
||||
|
||||
padding = int(kernel_size > 1)
|
||||
self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
|
||||
self.down_sample = set_default_layer(
|
||||
up_resolution is not None and not up_resolution,
|
||||
nn.Conv2d,
|
||||
(out_channels, out_channels),
|
||||
{"kernel_size": 2, "stride": 2},
|
||||
)
|
||||
|
||||
if up_resolution is not None and not up_resolution:
|
||||
self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
|
||||
else:
|
||||
self.down_sample = nn.Identity()
|
||||
|
||||
def forward(self, x, time_embed):
|
||||
x = self.group_norm(x, time_embed)
|
||||
@@ -521,14 +454,18 @@ class Kandinsky3ResNetBlock(nn.Module):
|
||||
)
|
||||
]
|
||||
)
|
||||
self.shortcut_up_sample = set_default_layer(
|
||||
True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2}
|
||||
self.shortcut_up_sample = (
|
||||
nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
|
||||
if True in up_resolutions
|
||||
else nn.Identity()
|
||||
)
|
||||
self.shortcut_projection = set_default_layer(
|
||||
in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1}
|
||||
self.shortcut_projection = (
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
self.shortcut_down_sample = set_default_layer(
|
||||
False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2}
|
||||
self.shortcut_down_sample = (
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
|
||||
if False in up_resolutions
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, time_embed):
|
||||
@@ -546,9 +483,16 @@ class Kandinsky3ResNetBlock(nn.Module):
|
||||
class Kandinsky3AttentionPooling(nn.Module):
|
||||
def __init__(self, num_channels, context_dim, head_dim=64):
|
||||
super().__init__()
|
||||
self.attention = Attention(context_dim, num_channels, context_dim, head_dim)
|
||||
self.attention = Attention(
|
||||
context_dim,
|
||||
context_dim,
|
||||
dim_head=head_dim,
|
||||
out_dim=num_channels,
|
||||
out_bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x, context, context_mask=None):
|
||||
context_mask = context_mask.to(dtype=context.dtype)
|
||||
context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
|
||||
return x + context.squeeze(1)
|
||||
|
||||
@@ -557,7 +501,13 @@ class Kandinsky3AttentionBlock(nn.Module):
|
||||
def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
|
||||
super().__init__()
|
||||
self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
|
||||
self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)
|
||||
self.attention = Attention(
|
||||
num_channels,
|
||||
context_dim or num_channels,
|
||||
dim_head=head_dim,
|
||||
out_dim=num_channels,
|
||||
out_bias=False,
|
||||
)
|
||||
|
||||
hidden_channels = expansion_ratio * num_channels
|
||||
self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
|
||||
@@ -572,14 +522,10 @@ class Kandinsky3AttentionBlock(nn.Module):
|
||||
out = self.in_norm(x, time_embed)
|
||||
out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
|
||||
context = context if context is not None else out
|
||||
if context_mask is not None:
|
||||
context_mask = context_mask.to(dtype=context.dtype)
|
||||
|
||||
if image_mask is not None:
|
||||
mask_height, mask_width = image_mask.shape[-2:]
|
||||
kernel_size = (mask_height // height, mask_width // width)
|
||||
image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size)
|
||||
image_mask = image_mask.reshape(image_mask.shape[0], -1)
|
||||
|
||||
out = self.attention(out, context, context_mask, image_mask)
|
||||
out = self.attention(out, context, context_mask)
|
||||
out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
|
||||
x = x + out
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMR
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -494,18 +494,29 @@ class AltDiffusionPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -875,7 +886,10 @@ class AltDiffusionPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMR
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -505,18 +505,29 @@ class AltDiffusionImg2ImgPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -919,7 +930,10 @@ class AltDiffusionImg2ImgPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unet_motion_model import MotionAdapter
|
||||
from ...schedulers import (
|
||||
@@ -320,18 +320,29 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
@@ -651,7 +662,10 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -479,18 +479,29 @@ class StableDiffusionControlNetPipeline(
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -1067,7 +1078,10 @@ class StableDiffusionControlNetPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -597,18 +597,29 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -1284,7 +1295,10 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ from ...loaders import (
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
@@ -489,18 +489,29 @@ class StableDiffusionXLControlNetPipeline(
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
@@ -1169,7 +1180,10 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
# 3.2 Encode ip_adapter_image
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"]
|
||||
_import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"]
|
||||
_import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"]
|
||||
_import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -33,8 +33,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .kandinsky3_pipeline import Kandinsky3Pipeline
|
||||
from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline
|
||||
from .pipeline_kandinsky3 import Kandinsky3Pipeline
|
||||
from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
+172
-35
@@ -1,4 +1,4 @@
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
@@ -7,8 +7,10 @@ from ...loaders import LoraLoaderMixin
|
||||
from ...models import Kandinsky3UNet, VQModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
@@ -16,6 +18,23 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForText2Image
|
||||
>>> import torch
|
||||
|
||||
>>> pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
|
||||
|
||||
>>> generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
>>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
@@ -29,6 +48,13 @@ def downscale_height_and_width(height, width, scale_factor=8):
|
||||
|
||||
class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
model_cpu_offload_seq = "text_encoder->unet->movq"
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
"negative_attention_mask",
|
||||
"attention_mask",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -50,7 +76,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet]:
|
||||
for model in [self.text_encoder, self.unet, self.movq]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
@@ -77,12 +103,14 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
_cut_context=False,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
negative_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
@@ -101,6 +129,10 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
|
||||
negative_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
@@ -228,14 +260,21 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
attention_mask=None,
|
||||
negative_attention_mask=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -262,8 +301,42 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_attention_mask is None:
|
||||
raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_attention_mask is not None:
|
||||
if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
|
||||
f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
|
||||
f" {negative_attention_mask.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and attention_mask is None:
|
||||
raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
|
||||
|
||||
if prompt_embeds is not None and attention_mask is not None:
|
||||
if prompt_embeds.shape[:2] != attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
|
||||
f" {attention_mask.shape}."
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -276,11 +349,14 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
negative_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
latents=None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -289,7 +365,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
num_inference_steps (`int`, *optional*, defaults to 25):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
@@ -324,6 +400,10 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
|
||||
negative_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -343,12 +423,53 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
cut_context = True
|
||||
device = self._execution_device
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
attention_mask,
|
||||
negative_attention_mask,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -357,24 +478,21 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
self.do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
_cut_context=cut_context,
|
||||
attention_mask=attention_mask,
|
||||
negative_attention_mask=negative_attention_mask,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
|
||||
# 4. Prepare timesteps
|
||||
@@ -397,11 +515,11 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
self.text_encoder_offload_hook.offload()
|
||||
|
||||
# 7. Denoising loop
|
||||
# TODO(Yiyi): Correct the following line and use correctly
|
||||
# num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
@@ -412,7 +530,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
|
||||
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
|
||||
@@ -425,26 +543,45 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
latents,
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
attention_mask = callback_outputs.pop("attention_mask", attention_mask)
|
||||
negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# post-processing
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
if output_type not in ["pt", "np", "pil", "latent"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
|
||||
f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
|
||||
)
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
if not output_type == "latent":
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
+228
-34
@@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
@@ -11,8 +11,10 @@ from ...loaders import LoraLoaderMixin
|
||||
from ...models import Kandinsky3UNet, VQModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
@@ -20,6 +22,24 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForImage2Image
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import torch
|
||||
|
||||
>>> pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> prompt = "A painting of the inside of a subway train with tiny raccoons."
|
||||
>>> image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png")
|
||||
|
||||
>>> generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
>>> image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
@@ -40,7 +60,14 @@ def prepare_image(pil_image):
|
||||
|
||||
|
||||
class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
model_cpu_offload_seq = "text_encoder->unet->movq"
|
||||
model_cpu_offload_seq = "text_encoder->movq->unet->movq"
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
"negative_attention_mask",
|
||||
"attention_mask",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -99,6 +126,8 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
_cut_context=False,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
negative_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
@@ -123,6 +152,10 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
|
||||
negative_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
@@ -299,15 +332,23 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
attention_mask=None,
|
||||
negative_attention_mask=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -334,7 +375,42 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_attention_mask is None:
|
||||
raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_attention_mask is not None:
|
||||
if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
|
||||
f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
|
||||
f" {negative_attention_mask.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and attention_mask is None:
|
||||
raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
|
||||
|
||||
if prompt_embeds is not None and attention_mask is not None:
|
||||
if prompt_embeds.shape[:2] != attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
|
||||
f" {attention_mask.shape}."
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -347,15 +423,117 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
negative_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
latents=None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
||||
process.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||
essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
|
||||
negative_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
|
||||
"""
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
cut_context = True
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
attention_mask,
|
||||
negative_attention_mask,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -366,24 +544,21 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
self.do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
_cut_context=cut_context,
|
||||
attention_mask=attention_mask,
|
||||
negative_attention_mask=negative_attention_mask,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
|
||||
if not isinstance(image, list):
|
||||
@@ -409,11 +584,11 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
self.text_encoder_offload_hook.offload()
|
||||
|
||||
# 7. Denoising loop
|
||||
# TODO(Yiyi): Correct the following line and use correctly
|
||||
# num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
@@ -422,7 +597,7 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=attention_mask,
|
||||
)[0]
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
|
||||
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
|
||||
@@ -434,25 +609,44 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
latents,
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
attention_mask = callback_outputs.pop("attention_mask", attention_mask)
|
||||
negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# post-processing
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
if output_type not in ["pt", "np", "pil", "latent"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
|
||||
f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
|
||||
)
|
||||
if not output_type == "latent":
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
@@ -758,10 +758,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
if torch_dtype is not None:
|
||||
deprecate("torch_dtype", "0.25.0", "")
|
||||
deprecate("torch_dtype", "0.27.0", "")
|
||||
torch_device = kwargs.pop("torch_device", None)
|
||||
if torch_device is not None:
|
||||
deprecate("torch_device", "0.25.0", "")
|
||||
deprecate("torch_device", "0.27.0", "")
|
||||
|
||||
dtype_kwarg = kwargs.pop("dtype", None)
|
||||
device_kwarg = kwargs.pop("device", None)
|
||||
|
||||
@@ -134,6 +134,51 @@ ASPECT_RATIO_512_BIN = {
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using PixArt-Alpha.
|
||||
@@ -783,8 +828,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
|
||||
@@ -34,7 +34,6 @@ else:
|
||||
_import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"]
|
||||
|
||||
@@ -447,7 +447,8 @@ def convert_ldm_unet_checkpoint(
|
||||
|
||||
# Relevant to StableDiffusionUpscalePipeline
|
||||
if "num_class_embeds" in config:
|
||||
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
|
||||
if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
|
||||
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
|
||||
|
||||
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
||||
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
||||
@@ -1480,9 +1481,12 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
config_name = "stabilityai/stable-diffusion-2"
|
||||
config_kwargs = {"subfolder": "text_encoder"}
|
||||
|
||||
text_model = convert_open_clip_checkpoint(
|
||||
checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
|
||||
)
|
||||
if text_encoder is None:
|
||||
text_model = convert_open_clip_checkpoint(
|
||||
checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
|
||||
)
|
||||
else:
|
||||
text_model = text_encoder
|
||||
|
||||
try:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -489,18 +489,29 @@ class StableDiffusionPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -871,7 +882,10 @@ class StableDiffusionPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -503,18 +503,29 @@ class StableDiffusionImg2ImgPipeline(
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -923,7 +934,10 @@ class StableDiffusionImg2ImgPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
|
||||
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
@@ -574,18 +574,29 @@ class StableDiffusionInpaintPipeline(
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -1103,7 +1114,10 @@ class StableDiffusionInpaintPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
|
||||
@@ -31,9 +31,10 @@ from ...loaders import (
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
@@ -524,18 +525,29 @@ class StableDiffusionXLPipeline(
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
@@ -670,7 +682,6 @@ class StableDiffusionXLPipeline(
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
self.vae.to(dtype=torch.float32)
|
||||
@@ -681,6 +692,7 @@ class StableDiffusionXLPipeline(
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
@@ -718,6 +730,65 @@ class StableDiffusionXLPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
@@ -1087,7 +1158,10 @@ class StableDiffusionXLPipeline(
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
image_embeds = image_embeds.to(device)
|
||||
|
||||
+21
-7
@@ -32,7 +32,7 @@ from ...loaders import (
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
@@ -741,18 +741,29 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
def _get_add_time_ids(
|
||||
self,
|
||||
@@ -1259,7 +1270,10 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
image_embeds = image_embeds.to(device)
|
||||
|
||||
+21
-7
@@ -33,7 +33,7 @@ from ...loaders import (
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
@@ -462,18 +462,29 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
@@ -1568,7 +1579,10 @@ class StableDiffusionXLInpaintPipeline(
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
image_embeds = image_embeds.to(device)
|
||||
|
||||
+2
@@ -24,6 +24,7 @@ from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, Te
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
@@ -610,6 +611,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
|
||||
@@ -290,9 +290,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
if isinstance(self.guidance_scale, (int, float)):
|
||||
return self.guidance_scale
|
||||
return self.guidance_scale.max() > 1
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -417,10 +415,10 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
self._guidance_scale = max_guidance_scale
|
||||
do_classifier_free_guidance = max_guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input image
|
||||
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
|
||||
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
|
||||
|
||||
# NOTE: Stable Diffusion Video was conditioned on fps - 1, which
|
||||
# is why it is reduced here.
|
||||
@@ -436,7 +434,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
|
||||
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
|
||||
image_latents = image_latents.to(image_embeddings.dtype)
|
||||
|
||||
# cast back to fp16 if needed
|
||||
@@ -455,7 +453,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
image_embeddings.dtype,
|
||||
batch_size,
|
||||
num_videos_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
do_classifier_free_guidance,
|
||||
)
|
||||
added_time_ids = added_time_ids.to(device)
|
||||
|
||||
@@ -491,7 +489,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# Concatenate image_latents over channels dimention
|
||||
@@ -507,7 +505,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ from diffusers.utils import deprecate
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models import ModelMixin
|
||||
from ...models.activations import get_activation
|
||||
from ...models.attention import Attention
|
||||
from ...models.attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
@@ -50,6 +50,9 @@ def get_down_block(
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
num_attention_heads,
|
||||
transformer_layers_per_block,
|
||||
attention_type,
|
||||
attention_head_dim,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
downsample_padding=None,
|
||||
@@ -113,6 +116,10 @@ def get_up_block(
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
num_attention_heads,
|
||||
transformer_layers_per_block,
|
||||
resolution_idx,
|
||||
attention_type,
|
||||
attention_head_dim,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
dual_cross_attention=False,
|
||||
@@ -993,6 +1000,42 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
||||
@@ -162,6 +162,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
||||
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
||||
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
||||
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
@@ -189,6 +189,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
||||
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
||||
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
||||
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
@@ -184,6 +184,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
||||
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
||||
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
||||
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
@@ -172,6 +172,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
||||
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
||||
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
||||
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
@@ -191,10 +191,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
|
||||
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
||||
return self.sigmas.max()
|
||||
return max_sigma
|
||||
|
||||
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
||||
return (max_sigma**2 + 1) ** 0.5
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -289,6 +290,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
if sigmas.device.type == "cuda":
|
||||
self.sigmas = self.sigmas.tolist()
|
||||
self._step_index = None
|
||||
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
|
||||
@@ -175,6 +175,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
||||
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
||||
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
||||
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
@@ -67,7 +67,7 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
|
||||
current_lora_layer_sd = lora_layer.state_dict()
|
||||
for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
|
||||
# The matrix name can either be "down" or "up".
|
||||
lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
|
||||
lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param
|
||||
|
||||
return lora_state_dict
|
||||
|
||||
|
||||
@@ -213,7 +213,7 @@ def remove_handler(handler: logging.Handler) -> None:
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert handler is not None and handler not in _get_library_root_logger().handlers
|
||||
assert handler is not None and handler in _get_library_root_logger().handlers
|
||||
_get_library_root_logger().removeHandler(handler)
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from contextlib import contextmanager
|
||||
from distutils.util import strtobool
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -58,6 +58,17 @@ USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
# Set a backend environment variable for any extra module import required for a custom accelerator
|
||||
if "DIFFUSERS_TEST_BACKEND" in os.environ:
|
||||
backend = os.environ["DIFFUSERS_TEST_BACKEND"]
|
||||
try:
|
||||
_ = importlib.import_module(backend)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \
|
||||
to enable a specified backend.):\n{e}"
|
||||
) from e
|
||||
|
||||
if "DIFFUSERS_TEST_DEVICE" in os.environ:
|
||||
torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
|
||||
try:
|
||||
@@ -210,6 +221,36 @@ def require_torch_gpu(test_case):
|
||||
)
|
||||
|
||||
|
||||
# These decorators are for accelerator-specific behaviours that are not GPU-specific
|
||||
def require_torch_accelerator(test_case):
|
||||
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
|
||||
return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp16(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp64(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP64 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_training(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for training."""
|
||||
return unittest.skipUnless(
|
||||
is_torch_available() and backend_supports_training(torch_device),
|
||||
"test requires accelerator with training support",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def skip_mps(test_case):
|
||||
"""Decorator marking a test to skip if torch_device is 'mps'"""
|
||||
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
|
||||
@@ -766,3 +807,139 @@ def disable_full_determinism():
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
|
||||
torch.use_deterministic_algorithms(False)
|
||||
|
||||
|
||||
# Utils for custom and alternative accelerator devices
|
||||
def _is_torch_fp16_available(device):
|
||||
if not is_torch_available():
|
||||
return False
|
||||
|
||||
import torch
|
||||
|
||||
device = torch.device(device)
|
||||
|
||||
try:
|
||||
x = torch.zeros((2, 2), dtype=torch.float16).to(device)
|
||||
_ = x @ x
|
||||
except Exception as e:
|
||||
if device.type == "cuda":
|
||||
raise ValueError(
|
||||
f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_torch_fp64_available(device):
|
||||
if not is_torch_available():
|
||||
return False
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
x = torch.zeros((2, 2), dtype=torch.float64).to(device)
|
||||
_ = x @ x
|
||||
except Exception as e:
|
||||
if device.type == "cuda":
|
||||
raise ValueError(
|
||||
f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
|
||||
if is_torch_available():
|
||||
# Behaviour flags
|
||||
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
|
||||
|
||||
# Function definitions
|
||||
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
|
||||
BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
|
||||
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
|
||||
|
||||
|
||||
# This dispatches a defined function according to the accelerator from the function definitions.
|
||||
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
|
||||
if device not in dispatch_table:
|
||||
return dispatch_table["default"](*args, **kwargs)
|
||||
|
||||
fn = dispatch_table[device]
|
||||
|
||||
# Some device agnostic functions return values. Need to guard against 'None' instead at
|
||||
# user level
|
||||
if fn is None:
|
||||
return None
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
|
||||
# These are callables which automatically dispatch the function specific to the accelerator
|
||||
def backend_manual_seed(device: str, seed: int):
|
||||
return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
|
||||
|
||||
|
||||
def backend_empty_cache(device: str):
|
||||
return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
|
||||
|
||||
|
||||
def backend_device_count(device: str):
|
||||
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
|
||||
|
||||
|
||||
# These are callables which return boolean behaviour flags and can be used to specify some
|
||||
# device agnostic alternative where the feature is unsupported.
|
||||
def backend_supports_training(device: str):
|
||||
if not is_torch_available():
|
||||
return False
|
||||
|
||||
if device not in BACKEND_SUPPORTS_TRAINING:
|
||||
device = "default"
|
||||
|
||||
return BACKEND_SUPPORTS_TRAINING[device]
|
||||
|
||||
|
||||
# Guard for when Torch is not available
|
||||
if is_torch_available():
|
||||
# Update device function dict mapping
|
||||
def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
|
||||
try:
|
||||
# Try to import the function directly
|
||||
spec_fn = getattr(device_spec_module, attribute_name)
|
||||
device_fn_dict[torch_device] = spec_fn
|
||||
except AttributeError as e:
|
||||
# If the function doesn't exist, and there is no default, throw an error
|
||||
if "default" not in device_fn_dict:
|
||||
raise AttributeError(
|
||||
f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
|
||||
) from e
|
||||
|
||||
if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
|
||||
device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
|
||||
if not Path(device_spec_path).is_file():
|
||||
raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
|
||||
|
||||
try:
|
||||
import_name = device_spec_path[: device_spec_path.index(".py")]
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
|
||||
|
||||
device_spec_module = importlib.import_module(import_name)
|
||||
|
||||
try:
|
||||
device_name = device_spec_module.DEVICE_NAME
|
||||
except AttributeError:
|
||||
raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
|
||||
|
||||
if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
|
||||
msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
|
||||
msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
|
||||
raise ValueError(msg)
|
||||
|
||||
torch_device = device_name
|
||||
|
||||
# Add one entry here for each `BACKEND_*` dictionary.
|
||||
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
|
||||
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
|
||||
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
|
||||
update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
|
||||
|
||||
@@ -25,7 +25,11 @@ from diffusers.models.embeddings import get_timestep_embedding
|
||||
from diffusers.models.lora import LoRACompatibleLinear
|
||||
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from diffusers.models.transformer_2d import Transformer2DModel
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_manual_seed,
|
||||
require_torch_accelerator_with_fp64,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingsTests(unittest.TestCase):
|
||||
@@ -315,8 +319,7 @@ class ResnetBlock2DTests(unittest.TestCase):
|
||||
class Transformer2DModelTests(unittest.TestCase):
|
||||
def test_spatial_transformer_default(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
backend_manual_seed(torch_device, 0)
|
||||
|
||||
sample = torch.randn(1, 32, 64, 64).to(torch_device)
|
||||
spatial_transformer_block = Transformer2DModel(
|
||||
@@ -339,8 +342,7 @@ class Transformer2DModelTests(unittest.TestCase):
|
||||
|
||||
def test_spatial_transformer_cross_attention_dim(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
backend_manual_seed(torch_device, 0)
|
||||
|
||||
sample = torch.randn(1, 64, 64, 64).to(torch_device)
|
||||
spatial_transformer_block = Transformer2DModel(
|
||||
@@ -363,8 +365,7 @@ class Transformer2DModelTests(unittest.TestCase):
|
||||
|
||||
def test_spatial_transformer_timestep(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
backend_manual_seed(torch_device, 0)
|
||||
|
||||
num_embeds_ada_norm = 5
|
||||
|
||||
@@ -401,8 +402,7 @@ class Transformer2DModelTests(unittest.TestCase):
|
||||
|
||||
def test_spatial_transformer_dropout(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
backend_manual_seed(torch_device, 0)
|
||||
|
||||
sample = torch.randn(1, 32, 64, 64).to(torch_device)
|
||||
spatial_transformer_block = (
|
||||
@@ -427,11 +427,10 @@ class Transformer2DModelTests(unittest.TestCase):
|
||||
)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "MPS does not support float64")
|
||||
@require_torch_accelerator_with_fp64
|
||||
def test_spatial_transformer_discrete(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
backend_manual_seed(torch_device, 0)
|
||||
|
||||
num_embed = 5
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_accelerator_with_training,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
torch_device,
|
||||
@@ -536,7 +537,7 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertEqual(output_1.shape, output_2.shape)
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
|
||||
@require_torch_accelerator_with_training
|
||||
def test_training(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -553,7 +554,7 @@ class ModelTesterMixin:
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
|
||||
@require_torch_accelerator_with_training
|
||||
def test_ema_training(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -624,7 +625,7 @@ class ModelTesterMixin:
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||
@require_torch_accelerator_with_training
|
||||
def test_enable_disable_gradient_checkpointing(self):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
return # Skip test if model does not support gradient checkpointing
|
||||
|
||||
@@ -21,7 +21,14 @@ import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import PriorTransformer
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, slow, torch_all_close, torch_device
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin
|
||||
|
||||
@@ -157,7 +164,7 @@ class PriorTransformerIntegrationTests(unittest.TestCase):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
|
||||
@@ -18,7 +18,12 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import UNet1DModel
|
||||
from diffusers.utils.testing_utils import floats_tensor, slow, torch_device
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_manual_seed,
|
||||
floats_tensor,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
@@ -103,8 +108,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
def test_output_pretrained(self):
|
||||
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
backend_manual_seed(torch_device, 0)
|
||||
|
||||
num_features = model.config.in_channels
|
||||
seq_len = 16
|
||||
@@ -244,8 +248,7 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
backend_manual_seed(torch_device, 0)
|
||||
|
||||
num_features = value_function.config.in_channels
|
||||
seq_len = 14
|
||||
|
||||
@@ -24,6 +24,7 @@ from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
@@ -153,7 +154,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
@require_torch_accelerator
|
||||
def test_from_pretrained_accelerate(self):
|
||||
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model.to(torch_device)
|
||||
@@ -161,7 +162,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
@require_torch_accelerator
|
||||
def test_from_pretrained_accelerate_wont_change_results(self):
|
||||
# by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
|
||||
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
|
||||
@@ -18,6 +18,7 @@ import gc
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
@@ -25,14 +26,19 @@ from pytest import mark
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.models.embeddings import ImageProjection, Resampler
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_fp16,
|
||||
require_torch_accelerator_with_training,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
@@ -97,6 +103,85 @@ def create_ip_adapter_state_dict(model):
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
def create_ip_adapter_plus_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 1
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
if cross_attention_dim is not None:
|
||||
sd = IPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 2
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
cross_attention_dim = model.config["cross_attention_dim"]
|
||||
image_projection = Resampler(
|
||||
embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
|
||||
)
|
||||
|
||||
ip_image_projection_state_dict = OrderedDict()
|
||||
for k, v in image_projection.state_dict().items():
|
||||
if "2.to" in k:
|
||||
k = k.replace("2.to", "0.to")
|
||||
elif "3.0.weight" in k:
|
||||
k = k.replace("3.0.weight", "1.0.weight")
|
||||
elif "3.0.bias" in k:
|
||||
k = k.replace("3.0.bias", "1.0.bias")
|
||||
elif "3.0.weight" in k:
|
||||
k = k.replace("3.0.weight", "1.0.weight")
|
||||
elif "3.1.net.0.proj.weight" in k:
|
||||
k = k.replace("3.1.net.0.proj.weight", "1.1.weight")
|
||||
elif "3.net.2.weight" in k:
|
||||
k = k.replace("3.net.2.weight", "1.3.weight")
|
||||
elif "layers.0.0" in k:
|
||||
k = k.replace("layers.0.0", "layers.0.0.norm1")
|
||||
elif "layers.0.1" in k:
|
||||
k = k.replace("layers.0.1", "layers.0.0.norm2")
|
||||
elif "layers.1.0" in k:
|
||||
k = k.replace("layers.1.0", "layers.1.0.norm1")
|
||||
elif "layers.1.1" in k:
|
||||
k = k.replace("layers.1.1", "layers.1.0.norm2")
|
||||
elif "layers.2.0" in k:
|
||||
k = k.replace("layers.2.0", "layers.2.0.norm1")
|
||||
elif "layers.2.1" in k:
|
||||
k = k.replace("layers.2.1", "layers.2.0.norm2")
|
||||
|
||||
if "norm_cross" in k:
|
||||
ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v
|
||||
elif "layer_norm" in k:
|
||||
ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v
|
||||
elif "to_k" in k:
|
||||
ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0)
|
||||
elif "to_v" in k:
|
||||
continue
|
||||
elif "to_out.0" in k:
|
||||
ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v
|
||||
else:
|
||||
ip_image_projection_state_dict[k] = v
|
||||
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
train_kv = True
|
||||
train_q_out = True
|
||||
@@ -200,7 +285,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||
@require_torch_accelerator_with_training
|
||||
def test_gradient_checkpointing(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
@@ -724,6 +809,56 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_ip_adapter_plus(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without ip-adapter
|
||||
with torch.no_grad():
|
||||
sample1 = model(**inputs_dict).sample
|
||||
|
||||
# update inputs_dict for ip-adapter
|
||||
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
|
||||
image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
|
||||
|
||||
# make ip_adapter_1 and ip_adapter_2
|
||||
ip_adapter_1 = create_ip_adapter_plus_state_dict(model)
|
||||
|
||||
image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()}
|
||||
cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()}
|
||||
ip_adapter_2 = {}
|
||||
ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
|
||||
|
||||
# forward pass ip_adapter_1
|
||||
model._load_ip_adapter_weights(ip_adapter_1)
|
||||
assert model.config.encoder_hid_dim_type == "ip_image_proj"
|
||||
assert model.encoder_hid_proj is not None
|
||||
assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
|
||||
"IPAdapterAttnProcessor",
|
||||
"IPAdapterAttnProcessor2_0",
|
||||
)
|
||||
with torch.no_grad():
|
||||
sample2 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with ip_adapter_2
|
||||
model._load_ip_adapter_weights(ip_adapter_2)
|
||||
with torch.no_grad():
|
||||
sample3 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with ip_adapter_1 again
|
||||
model._load_ip_adapter_weights(ip_adapter_1)
|
||||
with torch.no_grad():
|
||||
sample4 = model(**inputs_dict).sample
|
||||
|
||||
assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
|
||||
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
@slow
|
||||
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
@@ -734,7 +869,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
@@ -752,6 +887,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
return model
|
||||
|
||||
@require_torch_gpu
|
||||
def test_set_attention_slice_auto(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
@@ -771,6 +907,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert mem_bytes < 5 * 10**9
|
||||
|
||||
@require_torch_gpu
|
||||
def test_set_attention_slice_max(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
@@ -790,6 +927,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert mem_bytes < 5 * 10**9
|
||||
|
||||
@require_torch_gpu
|
||||
def test_set_attention_slice_int(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
@@ -809,6 +947,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert mem_bytes < 5 * 10**9
|
||||
|
||||
@require_torch_gpu
|
||||
def test_set_attention_slice_list(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
@@ -845,7 +984,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator_with_fp16
|
||||
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
|
||||
latents = self.get_latents(seed)
|
||||
@@ -873,7 +1012,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator_with_fp16
|
||||
def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
|
||||
latents = self.get_latents(seed, fp16=True)
|
||||
@@ -901,7 +1040,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@skip_mps
|
||||
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
|
||||
latents = self.get_latents(seed)
|
||||
@@ -929,7 +1069,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator_with_fp16
|
||||
def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
|
||||
latents = self.get_latents(seed, fp16=True)
|
||||
@@ -957,7 +1097,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@skip_mps
|
||||
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
|
||||
latents = self.get_latents(seed, shape=(4, 9, 64, 64))
|
||||
@@ -985,7 +1126,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator_with_fp16
|
||||
def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
|
||||
latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
|
||||
@@ -1013,7 +1154,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator_with_fp16
|
||||
def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
|
||||
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user