Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 66f94eaa0c | |||
| c1f2609a40 | |||
| 552634d688 | |||
| 7d52558c15 | |||
| 3efe355d52 | |||
| 08e6558ab8 | |||
| 1547720209 | |||
| 674d43fd68 | |||
| e7a16666ea |
@@ -18,11 +18,11 @@ The abstract from the paper is:
|
||||
|
||||
*Video synthesis has recently made remarkable strides benefiting from the rapid development of diffusion models. However, it still encounters challenges in terms of semantic accuracy, clarity and spatio-temporal continuity. They primarily arise from the scarcity of well-aligned text-video data and the complex inherent structure of videos, making it difficult for the model to simultaneously ensure semantic and qualitative excellence. In this report, we propose a cascaded I2VGen-XL approach that enhances model performance by decoupling these two factors and ensures the alignment of the input data by utilizing static images as a form of crucial guidance. I2VGen-XL consists of two stages: i) the base stage guarantees coherent semantics and preserves content from input images by using two hierarchical encoders, and ii) the refinement stage enhances the video's details by incorporating an additional brief text and improves the resolution to 1280×720. To improve the diversity, we collect around 35 million single-shot text-video pairs and 6 billion text-image pairs to optimize the model. By this means, I2VGen-XL can simultaneously enhance the semantic accuracy, continuity of details and clarity of generated videos. Through extensive experiments, we have investigated the underlying principles of I2VGen-XL and compared it with current top methods, which can demonstrate its effectiveness on diverse data. The source code and models will be publicly available at [this https URL](https://i2vgen-xl.github.io/).*
|
||||
|
||||
The original codebase can be found [here](https://github.com/ali-vilab/i2vgen-xl/). The model checkpoints can be found [here](https://huggingface.co/ali-vilab/).
|
||||
The original codebase can be found [here](https://github.com/ali-vilab/i2vgen-xl/). The model checkpoints can be found [here](https://huggingface.co/ali-vilab/).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -31,7 +31,7 @@ Sample output with I2VGenXL:
|
||||
<table>
|
||||
<tr>
|
||||
<td><center>
|
||||
library.
|
||||
masterpiece, bestquality, sunset.
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/i2vgen-xl-example.gif"
|
||||
alt="library"
|
||||
@@ -43,9 +43,9 @@ Sample output with I2VGenXL:
|
||||
## Notes
|
||||
|
||||
* I2VGenXL always uses a `clip_skip` value of 1. This means it leverages the penultimate layer representations from the text encoder of CLIP.
|
||||
* It can generate videos of quality that is often on par with [Stable Video Diffusion](../../using-diffusers/svd) (SVD).
|
||||
* Unlike SVD, it additionally accepts text prompts as inputs.
|
||||
* It can generate higher resolution videos.
|
||||
* It can generate videos of quality that is often on par with [Stable Video Diffusion](../../using-diffusers/svd) (SVD).
|
||||
* Unlike SVD, it additionally accepts text prompts as inputs.
|
||||
* It can generate higher resolution videos.
|
||||
* When using the [`DDIMScheduler`] (which is default for this pipeline), less than 50 steps for inference leads to bad results.
|
||||
|
||||
## I2VGenXLPipeline
|
||||
|
||||
@@ -70,7 +70,7 @@ Here are some sample outputs:
|
||||
<table>
|
||||
<tr>
|
||||
<td><center>
|
||||
cat in a field.
|
||||
masterpiece, bestquality, sunset.
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-default-output.gif"
|
||||
alt="cat in a field"
|
||||
@@ -119,7 +119,7 @@ image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
|
||||
)
|
||||
image = image.resize((512, 512))
|
||||
prompt = "cat in a field"
|
||||
prompt = "cat in a hat"
|
||||
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
@@ -132,7 +132,7 @@ export_to_gif(frames, "pia-freeinit-animation.gif")
|
||||
<table>
|
||||
<tr>
|
||||
<td><center>
|
||||
cat in a field.
|
||||
masterpiece, bestquality, sunset.
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-freeinit-output-cat.gif"
|
||||
alt="cat in a field"
|
||||
|
||||
@@ -41,7 +41,7 @@ pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", tor
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Spiderman is surfing"
|
||||
video_frames = pipe(prompt).frames[0]
|
||||
video_frames = pipe(prompt).frames
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
@@ -64,7 +64,7 @@ pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
prompt = "Darth Vader surfing a wave"
|
||||
video_frames = pipe(prompt, num_frames=64).frames[0]
|
||||
video_frames = pipe(prompt, num_frames=64).frames
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
@@ -83,7 +83,7 @@ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "Spiderman is surfing"
|
||||
video_frames = pipe(prompt, num_inference_steps=25).frames[0]
|
||||
video_frames = pipe(prompt, num_inference_steps=25).frames
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
@@ -130,7 +130,7 @@ pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
prompt = "Darth Vader surfing a wave"
|
||||
video_frames = pipe(prompt, num_frames=24).frames[0]
|
||||
video_frames = pipe(prompt, num_frames=24).frames
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
@@ -148,7 +148,7 @@ pipe.enable_vae_slicing()
|
||||
|
||||
video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]
|
||||
|
||||
video_frames = pipe(prompt, video=video, strength=0.6).frames[0]
|
||||
video_frames = pipe(prompt, video=video, strength=0.6).frames
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
# Advanced diffusion training examples
|
||||
|
||||
## Train Dreambooth LoRA with Stable Diffusion XL
|
||||
> [!TIP]
|
||||
> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
|
||||
|
||||
LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
|
||||
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
|
||||
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
|
||||
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
|
||||
- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.
|
||||
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
|
||||
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
|
||||
|
||||
The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py`, with
|
||||
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
|
||||
[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
|
||||
|
||||
> [!NOTE]
|
||||
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
|
||||
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
|
||||
|
||||
📚 Read more about the advanced features and best practices in this community derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)
|
||||
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/advanced_diffusion_training` folder and run
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Pivotal Tuning
|
||||
**Training with text encoder(s)**
|
||||
|
||||
Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
|
||||
available with `train_dreambooth_lora_sdxl_advanced.py`, in the advanced script **pivotal tuning** is also supported.
|
||||
[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
|
||||
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
|
||||
We then optimize the newly-inserted token embeddings to represent the new concept.
|
||||
|
||||
To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
|
||||
Please keep the following points in mind:
|
||||
|
||||
* SDXL has two text encoders. So, we fine-tune both using LoRA.
|
||||
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memoםהקרry.
|
||||
|
||||
|
||||
### 3D icon example
|
||||
|
||||
Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./3d_icon"
|
||||
snapshot_download(
|
||||
"LinoyTsaban/3d_icon",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
Let's review some of the advanced features we're going to be using for this example:
|
||||
- **custom captions**:
|
||||
To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
|
||||
```bash
|
||||
pip install datasets
|
||||
```
|
||||
|
||||
Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
|
||||
|
||||
```
|
||||
--dataset_name=./3d_icon
|
||||
--caption_column=prompt
|
||||
```
|
||||
|
||||
You can also load a dataset straight from by specifying it's name in `dataset_name`.
|
||||
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset.
|
||||
|
||||
- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
|
||||
- **pivotal tuning**
|
||||
- **min SNR gamma**
|
||||
|
||||
**Now, we can launch training:**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export DATASET_NAME="./3d_icon"
|
||||
export OUTPUT_DIR="3d-icon-SDXL-LoRA"
|
||||
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
|
||||
|
||||
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--pretrained_vae_model_name_or_path=$VAE_PATH \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=3 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--snr_gamma=5.0 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
1. starting with loading the unet lora weights
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL
|
||||
from safetensors.torch import load_file
|
||||
|
||||
username = "linoyts"
|
||||
repo_id = f"{username}/3d-icon-SDXL-LoRA"
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
|
||||
|
||||
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
|
||||
```
|
||||
2. now we load the pivotal tuning embeddings
|
||||
|
||||
```python
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
|
||||
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")
|
||||
|
||||
state_dict = load_file(embedding_path)
|
||||
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
|
||||
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
|
||||
```
|
||||
|
||||
3. let's generate images
|
||||
|
||||
```python
|
||||
instance_token = "<s0><s1>"
|
||||
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
|
||||
|
||||
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image.save("llama.png")
|
||||
```
|
||||
|
||||
### Comfy UI / AUTOMATIC1111 Inference
|
||||
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
|
||||
|
||||
**AUTOMATIC1111 / SD.Next** \
|
||||
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
|
||||
|
||||
You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
|
||||
|
||||
**ComfyUI** \
|
||||
In ComfyUI we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
|
||||
-
|
||||
### Specifying a better VAE
|
||||
|
||||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
|
||||
### Tips and Tricks
|
||||
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
|
||||
|
||||
## Running on Colab Notebook
|
||||
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb).
|
||||
to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free colab, using some of the advanced features (excluding pivotal tuning)
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft==0.7.0
|
||||
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -119,9 +119,10 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
"""
|
||||
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
|
||||
- Place it on it on your `embeddings` folder
|
||||
@@ -388,7 +389,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
default=1024,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
@@ -644,7 +645,6 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
@@ -745,11 +745,10 @@ class TokenEmbeddingsHandler:
|
||||
|
||||
idx += 1
|
||||
|
||||
# copied from train_dreambooth_lora_sdxl_advanced.py
|
||||
def save_embeddings(self, file_path: str):
|
||||
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
|
||||
tensors = {}
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
|
||||
@@ -1635,11 +1634,6 @@ def main(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn(
|
||||
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
|
||||
)
|
||||
bsz = model_input.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -1794,7 +1788,6 @@ def main(args):
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer_one,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
@@ -1867,11 +1860,6 @@ def main(args):
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
|
||||
embedding_handler.save_embeddings(embeddings_path)
|
||||
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
# Final inference
|
||||
@@ -1907,18 +1895,6 @@ def main(args):
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# load new tokens
|
||||
if args.train_text_encoder_ti:
|
||||
state_dict = load_file(embeddings_path)
|
||||
all_new_tokens = []
|
||||
for key, value in token_abstraction_dict.items():
|
||||
all_new_tokens.extend(value)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_l"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
)
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
@@ -1941,6 +1917,11 @@ def main(args):
|
||||
}
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
|
||||
)
|
||||
|
||||
# Conver to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
|
||||
@@ -20,7 +20,6 @@ import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
@@ -46,7 +45,6 @@ from PIL.ImageOps import exif_transpose
|
||||
from safetensors.torch import load_file, save_file
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import crop
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
@@ -74,7 +72,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -123,7 +121,7 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
@@ -399,6 +397,18 @@ def parse_args(input_args=None):
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_h",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_w",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
@@ -408,11 +418,6 @@ def parse_args(input_args=None):
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_flip",
|
||||
action="store_true",
|
||||
help="whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
@@ -654,7 +659,6 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
@@ -897,41 +901,6 @@ class DreamBoothDataset(Dataset):
|
||||
self.instance_images = []
|
||||
for img in instance_images:
|
||||
self.instance_images.extend(itertools.repeat(img, repeats))
|
||||
|
||||
# image processing to prepare for using SD-XL micro-conditioning
|
||||
self.original_sizes = []
|
||||
self.crop_top_lefts = []
|
||||
self.pixel_values = []
|
||||
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
for image in self.instance_images:
|
||||
image = exif_transpose(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
self.original_sizes.append((image.height, image.width))
|
||||
image = train_resize(image)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
image = train_flip(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
||||
image = train_crop(image)
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
||||
image = crop(image, y1, x1, h, w)
|
||||
crop_top_left = (y1, x1)
|
||||
self.crop_top_lefts.append(crop_top_left)
|
||||
image = train_transforms(image)
|
||||
self.pixel_values.append(image)
|
||||
|
||||
self.num_instance_images = len(self.instance_images)
|
||||
self._length = self.num_instance_images
|
||||
|
||||
@@ -961,12 +930,12 @@ class DreamBoothDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = self.pixel_values[index % self.num_instance_images]
|
||||
original_size = self.original_sizes[index % self.num_instance_images]
|
||||
crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
|
||||
example["instance_images"] = instance_image
|
||||
example["original_size"] = original_size
|
||||
example["crop_top_left"] = crop_top_left
|
||||
instance_image = self.instance_images[index % self.num_instance_images]
|
||||
instance_image = exif_transpose(instance_image)
|
||||
|
||||
if not instance_image.mode == "RGB":
|
||||
instance_image = instance_image.convert("RGB")
|
||||
example["instance_images"] = self.image_transforms(instance_image)
|
||||
|
||||
if self.custom_instance_prompts:
|
||||
caption = self.custom_instance_prompts[index % self.num_instance_images]
|
||||
@@ -997,8 +966,6 @@ class DreamBoothDataset(Dataset):
|
||||
def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
original_sizes = [example["original_size"] for example in examples]
|
||||
crop_top_lefts = [example["crop_top_left"] for example in examples]
|
||||
|
||||
# Concat class and instance examples for prior preservation.
|
||||
# We do this to avoid doing two forward passes.
|
||||
@@ -1009,12 +976,7 @@ def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
batch = {
|
||||
"pixel_values": pixel_values,
|
||||
"prompts": prompts,
|
||||
"original_sizes": original_sizes,
|
||||
"crop_top_lefts": crop_top_lefts,
|
||||
}
|
||||
batch = {"pixel_values": pixel_values, "prompts": prompts}
|
||||
return batch
|
||||
|
||||
|
||||
@@ -1236,9 +1198,7 @@ def main(args):
|
||||
args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
|
||||
if args.with_prior_preservation:
|
||||
args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
|
||||
if args.validation_prompt:
|
||||
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
|
||||
print("validation prompt:", args.validation_prompt)
|
||||
|
||||
# initialize the new tokens for textual inversion
|
||||
embedding_handler = TokenEmbeddingsHandler(
|
||||
[text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
|
||||
@@ -1579,11 +1539,11 @@ def main(args):
|
||||
# pooled text embeddings
|
||||
# time ids
|
||||
|
||||
def compute_time_ids(crops_coords_top_left, original_size=None):
|
||||
def compute_time_ids():
|
||||
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
||||
if original_size is None:
|
||||
original_size = (args.resolution, args.resolution)
|
||||
original_size = (args.resolution, args.resolution)
|
||||
target_size = (args.resolution, args.resolution)
|
||||
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -1600,6 +1560,9 @@ def main(args):
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# Handle instance prompt.
|
||||
instance_time_ids = compute_time_ids()
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
|
||||
# the redundant encoding.
|
||||
@@ -1610,6 +1573,7 @@ def main(args):
|
||||
|
||||
# Handle class prompt for prior-preservation.
|
||||
if args.with_prior_preservation:
|
||||
class_time_ids = compute_time_ids()
|
||||
if freeze_text_encoder:
|
||||
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.class_prompt, text_encoders, tokenizers
|
||||
@@ -1624,6 +1588,9 @@ def main(args):
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
add_time_ids = instance_time_ids
|
||||
if args.with_prior_preservation:
|
||||
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
|
||||
|
||||
# if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
|
||||
add_special_tokens = True if args.train_text_encoder_ti else False
|
||||
@@ -1646,6 +1613,12 @@ def main(args):
|
||||
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
if args.train_text_encoder_ti and args.validation_prompt:
|
||||
# replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
|
||||
for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
|
||||
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
|
||||
print("validation prompt:", args.validation_prompt)
|
||||
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
@@ -1805,12 +1778,6 @@ def main(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn(
|
||||
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
|
||||
)
|
||||
|
||||
bsz = model_input.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -1822,26 +1789,19 @@ def main(args):
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
|
||||
# time ids
|
||||
add_time_ids = torch.cat(
|
||||
[
|
||||
compute_time_ids(original_size=s, crops_coords_top_left=c)
|
||||
for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
|
||||
else:
|
||||
elems_to_repeat_text_embeds = 1
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
|
||||
# Predict the noise residual
|
||||
if freeze_text_encoder:
|
||||
unet_added_conditions = {
|
||||
"time_ids": add_time_ids,
|
||||
# "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
|
||||
}
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
@@ -1852,7 +1812,7 @@ def main(args):
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
).sample
|
||||
else:
|
||||
unet_added_conditions = {"time_ids": add_time_ids}
|
||||
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=None,
|
||||
@@ -1994,8 +1954,6 @@ def main(args):
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer_one,
|
||||
tokenizer_2=tokenizer_two,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
@@ -2075,11 +2033,6 @@ def main(args):
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
|
||||
embedding_handler.save_embeddings(embeddings_path)
|
||||
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
# Final inference
|
||||
@@ -2115,25 +2068,6 @@ def main(args):
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# load new tokens
|
||||
if args.train_text_encoder_ti:
|
||||
state_dict = load_file(embeddings_path)
|
||||
all_new_tokens = []
|
||||
for key, value in token_abstraction_dict.items():
|
||||
all_new_tokens.extend(value)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_l"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_g"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder_2,
|
||||
tokenizer=pipeline.tokenizer_2,
|
||||
)
|
||||
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
@@ -2156,6 +2090,11 @@ def main(args):
|
||||
}
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
|
||||
)
|
||||
|
||||
# Conver to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
|
||||
@@ -104,22 +104,6 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -141,8 +125,15 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
else:
|
||||
# get encoder_hidden_states, ip_hidden_states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
encoder_hidden_states[:, end_pos:, :],
|
||||
)
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
||||
@@ -242,22 +233,6 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -284,8 +259,15 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
else:
|
||||
# get encoder_hidden_states, ip_hidden_states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
encoder_hidden_states[:, end_pos:, :],
|
||||
)
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
||||
@@ -969,6 +951,30 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -1296,6 +1302,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
image_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated image embeddings.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -1404,7 +1411,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if image_embeds is not None:
|
||||
image_embeds = torch.stack([image_embeds] * num_images_per_prompt, dim=0).to(
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to(
|
||||
device=device, dtype=prompt_embeds.dtype
|
||||
)
|
||||
negative_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
@@ -40,8 +40,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
check_min_version("0.26.0")
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -66,7 +66,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@@ -41,7 +40,6 @@ from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import crop
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
@@ -67,7 +65,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -306,6 +304,18 @@ def parse_args(input_args=None):
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_h",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_w",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
@@ -315,11 +325,6 @@ def parse_args(input_args=None):
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_flip",
|
||||
action="store_true",
|
||||
help="whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
@@ -664,41 +669,6 @@ class DreamBoothDataset(Dataset):
|
||||
self.instance_images = []
|
||||
for img in instance_images:
|
||||
self.instance_images.extend(itertools.repeat(img, repeats))
|
||||
|
||||
# image processing to prepare for using SD-XL micro-conditioning
|
||||
self.original_sizes = []
|
||||
self.crop_top_lefts = []
|
||||
self.pixel_values = []
|
||||
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
for image in self.instance_images:
|
||||
image = exif_transpose(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
self.original_sizes.append((image.height, image.width))
|
||||
image = train_resize(image)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
image = train_flip(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
||||
image = train_crop(image)
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
||||
image = crop(image, y1, x1, h, w)
|
||||
crop_top_left = (y1, x1)
|
||||
self.crop_top_lefts.append(crop_top_left)
|
||||
image = train_transforms(image)
|
||||
self.pixel_values.append(image)
|
||||
|
||||
self.num_instance_images = len(self.instance_images)
|
||||
self._length = self.num_instance_images
|
||||
|
||||
@@ -728,12 +698,12 @@ class DreamBoothDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = self.pixel_values[index % self.num_instance_images]
|
||||
original_size = self.original_sizes[index % self.num_instance_images]
|
||||
crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
|
||||
example["instance_images"] = instance_image
|
||||
example["original_size"] = original_size
|
||||
example["crop_top_left"] = crop_top_left
|
||||
instance_image = self.instance_images[index % self.num_instance_images]
|
||||
instance_image = exif_transpose(instance_image)
|
||||
|
||||
if not instance_image.mode == "RGB":
|
||||
instance_image = instance_image.convert("RGB")
|
||||
example["instance_images"] = self.image_transforms(instance_image)
|
||||
|
||||
if self.custom_instance_prompts:
|
||||
caption = self.custom_instance_prompts[index % self.num_instance_images]
|
||||
@@ -760,8 +730,6 @@ class DreamBoothDataset(Dataset):
|
||||
def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
original_sizes = [example["original_size"] for example in examples]
|
||||
crop_top_lefts = [example["crop_top_left"] for example in examples]
|
||||
|
||||
# Concat class and instance examples for prior preservation.
|
||||
# We do this to avoid doing two forward passes.
|
||||
@@ -772,12 +740,7 @@ def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
batch = {
|
||||
"pixel_values": pixel_values,
|
||||
"prompts": prompts,
|
||||
"original_sizes": original_sizes,
|
||||
"crop_top_lefts": crop_top_lefts,
|
||||
}
|
||||
batch = {"pixel_values": pixel_values, "prompts": prompts}
|
||||
return batch
|
||||
|
||||
|
||||
@@ -1270,9 +1233,11 @@ def main(args):
|
||||
# pooled text embeddings
|
||||
# time ids
|
||||
|
||||
def compute_time_ids(original_size, crops_coords_top_left):
|
||||
def compute_time_ids():
|
||||
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
||||
original_size = (args.resolution, args.resolution)
|
||||
target_size = (args.resolution, args.resolution)
|
||||
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -1289,6 +1254,9 @@ def main(args):
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# Handle instance prompt.
|
||||
instance_time_ids = compute_time_ids()
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
|
||||
# the redundant encoding.
|
||||
@@ -1299,6 +1267,7 @@ def main(args):
|
||||
|
||||
# Handle class prompt for prior-preservation.
|
||||
if args.with_prior_preservation:
|
||||
class_time_ids = compute_time_ids()
|
||||
if not args.train_text_encoder:
|
||||
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.class_prompt, text_encoders, tokenizers
|
||||
@@ -1313,6 +1282,9 @@ def main(args):
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
add_time_ids = instance_time_ids
|
||||
if args.with_prior_preservation:
|
||||
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
|
||||
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
if not args.train_text_encoder:
|
||||
@@ -1427,8 +1399,8 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
text_encoder_one.text_model.embeddings.requires_grad_(True)
|
||||
text_encoder_two.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
@@ -1464,24 +1436,18 @@ def main(args):
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
|
||||
# time ids
|
||||
add_time_ids = torch.cat(
|
||||
[
|
||||
compute_time_ids(original_size=s, crops_coords_top_left=c)
|
||||
for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
else:
|
||||
elems_to_repeat_text_embeds = 1
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
|
||||
# Predict the noise residual
|
||||
if not args.train_text_encoder:
|
||||
unet_added_conditions = {
|
||||
"time_ids": add_time_ids,
|
||||
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
|
||||
}
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
@@ -1493,7 +1459,7 @@ def main(args):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
unet_added_conditions = {"time_ids": add_time_ids}
|
||||
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=None,
|
||||
|
||||
@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.26.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -249,7 +249,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.26.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.26.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.26.0.dev0"
|
||||
__version__ = "0.26.3"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -38,6 +38,9 @@ class FromOriginalVAEMixin:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
config_file (`str`, *optional*):
|
||||
Filepath to the configuration YAML file associated with the model. If not provided it will default to:
|
||||
https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
@@ -65,6 +68,13 @@ class FromOriginalVAEMixin:
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
|
||||
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
|
||||
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
@@ -92,6 +102,7 @@ class FromOriginalVAEMixin:
|
||||
"""
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -103,6 +114,13 @@ class FromOriginalVAEMixin:
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
|
||||
if (config_file is not None) and (original_config_file is not None):
|
||||
raise ValueError(
|
||||
"You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
|
||||
)
|
||||
|
||||
original_config_file = original_config_file or config_file
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
@@ -118,7 +136,10 @@ class FromOriginalVAEMixin:
|
||||
)
|
||||
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
|
||||
scaling_factor = kwargs.pop("scaling_factor", None)
|
||||
component = create_diffusers_vae_model_from_ldm(
|
||||
class_name, original_config, checkpoint, image_size=image_size, scaling_factor=scaling_factor
|
||||
)
|
||||
vae = component["vae"]
|
||||
if torch_dtype is not None:
|
||||
vae = vae.to(torch_dtype)
|
||||
|
||||
@@ -175,6 +175,7 @@ DIFFUSERS_TO_LDM_MAPPING = {
|
||||
}
|
||||
|
||||
LDM_VAE_KEY = "first_stage_model."
|
||||
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
||||
LDM_UNET_KEY = "model.diffusion_model."
|
||||
LDM_CONTROLNET_KEY = "control_model."
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
|
||||
@@ -518,7 +519,10 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
scaling_factor = scaling_factor or original_config["model"]["params"]["scale_factor"]
|
||||
if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
|
||||
scaling_factor = original_config["model"]["params"]["scale_factor"]
|
||||
elif scaling_factor is None:
|
||||
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
|
||||
|
||||
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
@@ -1174,7 +1178,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
|
||||
|
||||
def create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=0.18125
|
||||
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None
|
||||
):
|
||||
# import here to avoid circular imports
|
||||
from ..models import AutoencoderKL
|
||||
|
||||
@@ -1031,10 +1031,16 @@ class DownBlockMotion(nn.Module):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, scale
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(motion_module),
|
||||
hidden_states.requires_grad_(),
|
||||
temb,
|
||||
num_frames,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1215,10 +1221,10 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)[0]
|
||||
|
||||
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
||||
if i == len(blocks) - 1 and additional_residuals is not None:
|
||||
@@ -1419,10 +1425,10 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)[0]
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -1557,10 +1563,15 @@ class UpBlockMotion(nn.Module):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
|
||||
@@ -792,7 +792,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
@@ -800,9 +799,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
||||
)
|
||||
image_embeds = added_cond_kwargs.get("image_embeds")
|
||||
image_embeds = self.encoder_hid_proj(image_embeds)
|
||||
image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
|
||||
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
||||
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
# 2. pre-process
|
||||
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,205 +0,0 @@
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
|
||||
|
||||
class UNet2DConditionModelUtilsMixin:
|
||||
def _check_config(
|
||||
self,
|
||||
down_block_types: Tuple[str],
|
||||
up_block_types: Tuple[str],
|
||||
only_cross_attention: Union[bool, Tuple[bool]],
|
||||
block_out_channels: Tuple[int],
|
||||
layers_per_block: [int, Tuple[int]],
|
||||
cross_attention_dim: Union[int, Tuple[int]],
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
|
||||
reverse_transformer_layers_per_block: bool,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int]]],
|
||||
):
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
|
||||
for layer_number_per_block in transformer_layers_per_block:
|
||||
if isinstance(layer_number_per_block, list):
|
||||
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
||||
|
||||
@property
|
||||
def attn_processors(self):
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
||||
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
setattr(upsample_block, "s1", s1)
|
||||
setattr(upsample_block, "s2", s2)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
for k in freeu_keys:
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
@@ -11,13 +11,12 @@ from ...utils import BaseOutput
|
||||
@dataclass
|
||||
class AnimateDiffPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for AnimateDiff pipelines.
|
||||
Output class for AnimateDiff pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
|
||||
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`
|
||||
Args:
|
||||
frames (`List[List[PIL.Image.Image]]` or `torch.Tensor` or `np.ndarray`):
|
||||
List of PIL Images of length `batch_size` or torch.Tensor or np.ndarray of shape
|
||||
`(batch_size, num_frames, height, width, num_channels)`.
|
||||
"""
|
||||
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
frames: Union[List[List[PIL.Image.Image]], torch.Tensor, np.ndarray]
|
||||
|
||||
@@ -789,8 +789,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -705,8 +705,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -871,8 +871,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
-2
@@ -566,8 +566,6 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
-2
@@ -536,8 +536,6 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ EXAMPLE_DOC_STRING = """
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import I2VGenXLPipeline
|
||||
>>> from diffusers.utils import export_to_gif, load_image
|
||||
|
||||
>>> pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
|
||||
>>> pipeline.enable_model_cpu_offload()
|
||||
@@ -96,16 +95,15 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
|
||||
@dataclass
|
||||
class I2VGenXLPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for image-to-video pipeline.
|
||||
Output class for image-to-video pipeline.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
|
||||
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`
|
||||
Args:
|
||||
frames (`List[np.ndarray]` or `torch.FloatTensor`)
|
||||
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
|
||||
a `torch` tensor. The length of the list denotes the video length (the number of frames).
|
||||
"""
|
||||
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
frames: Union[List[np.ndarray], torch.FloatTensor]
|
||||
|
||||
|
||||
class I2VGenXLPipeline(DiffusionPipeline):
|
||||
|
||||
-2
@@ -634,8 +634,6 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.fft as fft
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
@@ -200,18 +200,16 @@ class PIAPipelineOutput(BaseOutput):
|
||||
Output class for PIAPipeline.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[PIL.Image.Image]):
|
||||
Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
|
||||
NumPy array of shape `(batch_size, num_frames, channels, height, width,
|
||||
Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
frames: Union[torch.Tensor, np.ndarray, PIL.Image.Image]
|
||||
|
||||
|
||||
class PIAPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
@@ -687,35 +685,6 @@ class PIAPipeline(
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
@@ -937,8 +906,6 @@ class PIAPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
@@ -1138,9 +1105,12 @@ class PIAPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_videos_per_prompt
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
@@ -467,8 +467,6 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -659,8 +659,6 @@ class StableDiffusionImg2ImgPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -859,8 +859,6 @@ class StableDiffusionInpaintPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
-2
@@ -754,8 +754,6 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ...utils import (
|
||||
@@ -13,13 +12,12 @@ from ...utils import (
|
||||
@dataclass
|
||||
class TextToVideoSDPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for text-to-video pipelines.
|
||||
Output class for text-to-video pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
|
||||
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`
|
||||
Args:
|
||||
frames (`List[np.ndarray]` or `torch.FloatTensor`)
|
||||
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
|
||||
a `torch` tensor. The length of the list denotes the video length (the number of frames).
|
||||
"""
|
||||
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
frames: Union[List[np.ndarray], torch.FloatTensor]
|
||||
|
||||
-2
@@ -554,8 +554,6 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -98,9 +98,15 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.custom_timesteps = False
|
||||
self.is_scale_input_called = False
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
return indices.item()
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
@@ -108,24 +114,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -243,7 +231,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Modified _convert_to_karras implementation that takes in ramp as argument
|
||||
@@ -293,29 +280,23 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -431,11 +412,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -187,7 +187,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -197,24 +196,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -274,7 +255,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -640,12 +620,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -658,20 +637,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
return step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -770,11 +736,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# begin_index is None when the scheduler is used for training
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = []
|
||||
for timestep in timesteps:
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(schedule_timesteps) - 1
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
step_indices.append(step_index)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -227,7 +227,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -237,23 +236,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -329,7 +311,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -811,11 +792,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -828,19 +809,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
return step_index
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -951,11 +920,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# begin_index is None when the scheduler is used for training
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = []
|
||||
for timestep in timesteps:
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(schedule_timesteps) - 1
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
step_indices.append(step_index)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -767,6 +767,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
@@ -878,6 +879,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -197,10 +198,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.noise_sampler = None
|
||||
self.noise_sampler_seed = noise_sampler_seed
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
@@ -211,18 +211,31 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -239,24 +252,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -353,10 +348,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.mid_point_sigma = None
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.noise_sampler = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
def _second_order_timesteps(self, sigmas, log_sigmas):
|
||||
def sigma_fn(_t):
|
||||
return np.exp(-_t)
|
||||
@@ -446,6 +444,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
# Create a noise sampler if it hasn't been created yet
|
||||
if self.noise_sampler is None:
|
||||
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
|
||||
@@ -525,7 +527,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
@@ -542,11 +544,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -151,7 +151,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
lower_order_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
@@ -210,7 +210,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
self.order_list = self.get_order_list(num_train_timesteps)
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||
@@ -233,7 +232,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
orders = [1, 2, 3] * (steps // 3) + [1, 2]
|
||||
elif order == 2:
|
||||
if steps % 2 == 0:
|
||||
orders = [1, 2] * (steps // 2)
|
||||
orders = [1, 2] * (steps // 2 - 1) + [1, 1]
|
||||
else:
|
||||
orders = [1, 2] * (steps // 2) + [1]
|
||||
elif order == 1:
|
||||
@@ -254,24 +253,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -320,7 +301,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
|
||||
logger.warn(
|
||||
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`."
|
||||
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=False`."
|
||||
)
|
||||
self.register_to_config(lower_order_final=True)
|
||||
|
||||
@@ -334,7 +315,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -833,12 +813,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise ValueError(f"Order must be 1, 2, 3, got {order}")
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -851,20 +830,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
return step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -959,11 +925,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# begin_index is None when the scheduler is used for training
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = []
|
||||
for timestep in timesteps:
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(schedule_timesteps) - 1
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
step_indices.append(step_index)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -216,7 +216,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -234,24 +233,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -319,32 +300,25 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -466,11 +440,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -237,7 +237,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -256,24 +255,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -361,7 +342,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
@@ -413,27 +393,22 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -563,11 +538,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -147,10 +148,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
@@ -161,7 +160,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
@@ -180,24 +183,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -285,9 +270,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.dt = None
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
@@ -344,12 +333,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -380,6 +378,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# (YiYi notes: keep this for now since we are keeping the add_noise method)
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
@@ -450,7 +453,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
@@ -467,11 +469,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -56,7 +56,6 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# running values
|
||||
self.ets = []
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -65,24 +64,6 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -109,31 +90,24 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.ets = []
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -139,9 +140,27 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# set all values
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
@@ -157,24 +176,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -294,8 +295,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sample = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
@@ -352,29 +356,23 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -408,6 +406,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_interpol = self.sigmas_interpol[self.step_index]
|
||||
@@ -476,7 +478,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
@@ -493,11 +495,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -139,9 +140,27 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
@@ -157,24 +176,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -279,37 +280,34 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sample = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
@@ -390,6 +388,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_interpol = self.sigmas_interpol[self.step_index + 1]
|
||||
@@ -451,7 +453,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
@@ -468,11 +470,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -250,54 +250,29 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.custom_timesteps = False
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
@@ -487,7 +462,6 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def get_scalings_for_boundary_condition_discrete(self, timestep):
|
||||
self.sigma_data = 0.5 # Default: 0.5
|
||||
|
||||
@@ -168,7 +168,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -186,24 +185,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -299,34 +280,27 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
self.derivatives = []
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
@@ -460,11 +434,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -212,7 +212,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.lower_order_nums = 0
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -222,24 +221,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -302,7 +283,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -945,12 +925,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -963,20 +942,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
return step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
|
||||
@@ -198,7 +198,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.solver_p = solver_p
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -208,24 +207,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -288,7 +269,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -718,12 +698,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -736,20 +715,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
return step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -864,11 +830,16 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
# begin_index is None when the scheduler is used for training
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
step_indices = []
|
||||
for timestep in timesteps:
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(schedule_timesteps) - 1
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
step_indices.append(step_index)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -854,8 +854,6 @@ def _is_torch_fp64_available(device):
|
||||
|
||||
import torch
|
||||
|
||||
device = torch.device(device)
|
||||
|
||||
try:
|
||||
x = torch.zeros((2, 2), dtype=torch.float64).to(device)
|
||||
_ = torch.mul(x, x)
|
||||
|
||||
@@ -25,7 +25,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
+1
-1
@@ -48,7 +48,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
+1
-1
@@ -23,7 +23,7 @@ from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
+1
-1
@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
+1
-1
@@ -28,7 +28,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
+1
-1
@@ -4,7 +4,7 @@ from diffusers import FlaxAutoencoderKL
|
||||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
from ..test_modeling_common_flax import FlaxModelTesterMixin
|
||||
from .test_modeling_common_flax import FlaxModelTesterMixin
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -25,7 +25,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -35,7 +35,6 @@ from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -120,8 +119,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
|
||||
expected_slice = np.array([0.80810547, 0.88183594, 0.9296875, 0.9189453, 0.9848633, 1.0, 0.97021484, 1.0, 1.0])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
|
||||
|
||||
@@ -133,8 +131,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.30444336, 0.26513672, 0.22436523, 0.2758789, 0.25585938, 0.20751953, 0.25390625, 0.24633789, 0.21923828]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
def test_image_to_image(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -152,8 +149,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.22167969, 0.21875, 0.21728516, 0.22607422, 0.21948242, 0.23925781, 0.22387695, 0.25268555, 0.2722168]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
|
||||
|
||||
@@ -165,8 +161,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.35913086, 0.265625, 0.26367188, 0.24658203, 0.19750977, 0.39990234, 0.15258789, 0.20336914, 0.5517578]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
def test_inpainting(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -184,8 +179,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.27148438, 0.24047852, 0.22167969, 0.23217773, 0.21118164, 0.21142578, 0.21875, 0.20751953, 0.20019531]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
|
||||
|
||||
@@ -193,8 +187,11 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
expected_slice = np.array(
|
||||
[0.27294922, 0.24023438, 0.21948242, 0.23242188, 0.20825195, 0.2055664, 0.21679688, 0.20336914, 0.19360352]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
def test_text_to_image_model_cpu_offload(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -236,10 +233,11 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array([0.1958, 0.1475, 0.1396, 0.2412, 0.1658, 0.1533, 0.3997, 0.4055, 0.4128])
|
||||
expected_slice = np.array(
|
||||
[0.18115234, 0.13500977, 0.13427734, 0.24194336, 0.17138672, 0.16625977, 0.4260254, 0.43359375, 0.4416504]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
def test_unload(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -279,9 +277,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
expected_slice = np.array(
|
||||
[0.5234375, 0.53515625, 0.5629883, 0.57128906, 0.59521484, 0.62109375, 0.57910156, 0.6201172, 0.6508789]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
@@ -318,8 +314,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
|
||||
@@ -344,8 +339,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.0576596, 0.05600825, 0.04479006, 0.05288461, 0.05461192, 0.05137569, 0.04867965, 0.05301541, 0.04939842]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
def test_image_to_image_sdxl(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
|
||||
@@ -438,8 +432,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.14181179, 0.1493012, 0.14283323, 0.14602411, 0.14915377, 0.15015268, 0.14725655, 0.15009224, 0.15164584]
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
@@ -464,5 +457,4 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
|
||||
expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@@ -38,7 +38,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..models.autoencoders.test_models_vae import (
|
||||
from ..models.test_models_vae import (
|
||||
get_asym_autoencoder_kl_config,
|
||||
get_autoencoder_kl_config,
|
||||
get_autoencoder_tiny_config,
|
||||
|
||||
Reference in New Issue
Block a user