Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| abf4a9271e | |||
| 0e1fb0d916 | |||
| f77b7a0f27 | |||
| eae1371983 |
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.7"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.7"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.7"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.8
|
||||
python-version: 3.7
|
||||
|
||||
- name: Install requirements
|
||||
run: |
|
||||
|
||||
@@ -216,8 +216,6 @@
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/blip_diffusion
|
||||
title: BLIP Diffusion
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# Blip Diffusion
|
||||
|
||||
Blip Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
|
||||
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.*
|
||||
|
||||
The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.
|
||||
|
||||
`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## BlipDiffusionPipeline
|
||||
[[autodoc]] BlipDiffusionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## BlipDiffusionControlNetPipeline
|
||||
[[autodoc]] BlipDiffusionControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Install 🤗 Diffusers for whichever deep learning library you're working with.
|
||||
|
||||
🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
|
||||
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
|
||||
|
||||
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
|
||||
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
|
||||
@@ -106,7 +106,7 @@ pip install -e ".[flax]"
|
||||
|
||||
These commands will link the folder you cloned the repository to and your Python library paths.
|
||||
Python will now look inside the folder you cloned to in addition to the normal library paths.
|
||||
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
|
||||
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
사용하시는 라이브러리에 맞는 🤗 Diffusers를 설치하세요.
|
||||
|
||||
🤗 Diffusers는 Python 3.8+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.
|
||||
🤗 Diffusers는 Python 3.7+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.
|
||||
|
||||
- [PyTorch 설치 안내](https://pytorch.org/get-started/locally/)
|
||||
- [Flax 설치 안내](https://flax.readthedocs.io/en/latest/)
|
||||
@@ -105,7 +105,7 @@ pip install -e ".[flax]"
|
||||
|
||||
이러한 명령어들은 저장소를 복제한 폴더와 Python 라이브러리 경로를 연결합니다.
|
||||
Python은 이제 일반 라이브러리 경로에 더하여 복제한 폴더 내부를 살펴봅니다.
|
||||
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.8/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
|
||||
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.7/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
在你正在使用的任意深度学习框架中安装 🤗 Diffusers 。
|
||||
|
||||
🤗 Diffusers已在Python 3.8+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装:
|
||||
🤗 Diffusers已在Python 3.7+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装:
|
||||
|
||||
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
|
||||
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
|
||||
@@ -107,7 +107,7 @@ pip install -e ".[flax]"
|
||||
|
||||
这些命令将连接到你克隆的版本库和你的 Python 库路径。
|
||||
现在,不只是在通常的库路径,Python 还会在你克隆的文件夹内寻找包。
|
||||
例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.8/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`。
|
||||
例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.7/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`。
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -908,9 +908,6 @@ def main():
|
||||
if args.snr_gamma is not None:
|
||||
snr = jnp.array(compute_snr(timesteps))
|
||||
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
snr_loss_weights = snr_loss_weights + 1
|
||||
loss = loss * snr_loss_weights
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -224,30 +224,6 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
|
||||
raise ValueError(f"{model_class} is not supported.")
|
||||
|
||||
|
||||
def compute_snr(timesteps, noise_scheduler):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -548,13 +524,6 @@ def parse_args(input_args=None):
|
||||
" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--snr_gamma",
|
||||
type=float,
|
||||
default=None,
|
||||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
||||
"More details here: https://arxiv.org/abs/2303.09556.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pre_compute_text_embeddings",
|
||||
action="store_true",
|
||||
@@ -1292,34 +1261,17 @@ def main(args):
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
|
||||
# Compute instance loss
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
|
||||
# Compute instance loss
|
||||
if args.snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps, noise_scheduler)
|
||||
base_weight = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = loss + args.prior_loss_weight * prior_loss
|
||||
else:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -875,9 +875,6 @@ def main():
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
mse_loss_weights = mse_loss_weights + 1
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -955,9 +955,6 @@ def main():
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
mse_loss_weights = mse_loss_weights + 1
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -786,9 +786,6 @@ def main():
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
mse_loss_weights = mse_loss_weights + 1
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -1075,9 +1075,6 @@ def main(args):
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
mse_loss_weights = mse_loss_weights + 1
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -332,6 +332,15 @@ def parse_args(input_args=None):
|
||||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
||||
"More details here: https://arxiv.org/abs/2303.09556.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_snr_gamma",
|
||||
action="store_true",
|
||||
help=(
|
||||
"When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN"
|
||||
" condition when computing the SNR with a sigma value of zero. This parameter overrides the check,"
|
||||
" allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
@@ -545,6 +554,18 @@ def main(args):
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
# Check for terminal SNR in combination with SNR Gamma
|
||||
if (
|
||||
args.snr_gamma
|
||||
and not args.force_snr_gamma
|
||||
and (
|
||||
hasattr(noise_scheduler.config, "rescale_betas_zero_snr") and noise_scheduler.config.rescale_betas_zero_snr
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"The selected noise scheduler for the model {args.pretrained_model_name_or_path} uses rescaled betas for zero SNR.\n"
|
||||
"When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n"
|
||||
"This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero."
|
||||
)
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
@@ -977,11 +998,6 @@ def main(args):
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
|
||||
elif noise_scheduler.config.prediction_type == "sample":
|
||||
# We set the target to latents here, but the model_pred will return the noise sample prediction.
|
||||
target = model_input
|
||||
# We will have to subtract the noise residual from the prediction to get the target sample.
|
||||
model_pred = model_pred - noise
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
@@ -992,17 +1008,9 @@ def main(args):
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
|
||||
@@ -1,343 +0,0 @@
|
||||
"""
|
||||
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines import BlipDiffusionPipeline
|
||||
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
|
||||
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
||||
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
||||
|
||||
|
||||
BLIP2_CONFIG = {
|
||||
"vision_config": {
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 23,
|
||||
"num_attention_heads": 16,
|
||||
"image_size": 224,
|
||||
"patch_size": 14,
|
||||
"intermediate_size": 4096,
|
||||
"hidden_act": "quick_gelu",
|
||||
},
|
||||
"qformer_config": {
|
||||
"cross_attention_frequency": 1,
|
||||
"encoder_hidden_size": 1024,
|
||||
"vocab_size": 30523,
|
||||
},
|
||||
"num_query_tokens": 16,
|
||||
}
|
||||
blip2config = Blip2Config(**BLIP2_CONFIG)
|
||||
|
||||
|
||||
def qformer_model_from_original_config():
|
||||
qformer = Blip2QFormerModel(blip2config)
|
||||
return qformer
|
||||
|
||||
|
||||
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
|
||||
embeddings = {}
|
||||
embeddings.update(
|
||||
{
|
||||
f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
|
||||
f"{original_embeddings_prefix}.word_embeddings.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
embeddings.update(
|
||||
{
|
||||
f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
|
||||
f"{original_embeddings_prefix}.position_embeddings.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
embeddings.update(
|
||||
{f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
|
||||
)
|
||||
embeddings.update(
|
||||
{f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
|
||||
proj_layer = {}
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
|
||||
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
|
||||
return proj_layer
|
||||
|
||||
|
||||
def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
|
||||
attention = {}
|
||||
attention.update(
|
||||
{
|
||||
f"{diffuser_attention_prefix}.attention.query.weight": model[
|
||||
f"{original_attention_prefix}.self.query.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
|
||||
)
|
||||
attention.update(
|
||||
{
|
||||
f"{diffuser_attention_prefix}.attention.value.weight": model[
|
||||
f"{original_attention_prefix}.self.value.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
|
||||
)
|
||||
attention.update(
|
||||
{f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
|
||||
)
|
||||
attention.update(
|
||||
{
|
||||
f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
|
||||
f"{original_attention_prefix}.output.LayerNorm.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
attention.update(
|
||||
{
|
||||
f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
|
||||
f"{original_attention_prefix}.output.LayerNorm.bias"
|
||||
]
|
||||
}
|
||||
)
|
||||
return attention
|
||||
|
||||
|
||||
def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
|
||||
output_layers = {}
|
||||
output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
|
||||
output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
|
||||
output_layers.update(
|
||||
{f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
|
||||
)
|
||||
output_layers.update(
|
||||
{f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
|
||||
)
|
||||
return output_layers
|
||||
|
||||
|
||||
def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
|
||||
encoder = {}
|
||||
for i in range(blip2config.qformer_config.num_hidden_layers):
|
||||
encoder.update(
|
||||
attention_from_original_checkpoint(
|
||||
model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
|
||||
)
|
||||
)
|
||||
encoder.update(
|
||||
attention_from_original_checkpoint(
|
||||
model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
|
||||
)
|
||||
)
|
||||
|
||||
encoder.update(
|
||||
{
|
||||
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
|
||||
f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
encoder.update(
|
||||
{
|
||||
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
|
||||
f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
|
||||
]
|
||||
}
|
||||
)
|
||||
encoder.update(
|
||||
{
|
||||
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
|
||||
f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
|
||||
]
|
||||
}
|
||||
)
|
||||
encoder.update(
|
||||
{
|
||||
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
|
||||
f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
encoder.update(
|
||||
output_layers_from_original_checkpoint(
|
||||
model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
|
||||
)
|
||||
)
|
||||
encoder.update(
|
||||
output_layers_from_original_checkpoint(
|
||||
model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
|
||||
)
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
||||
def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
|
||||
visual_encoder_layer = {}
|
||||
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
|
||||
visual_encoder_layer.update(
|
||||
{f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
|
||||
)
|
||||
visual_encoder_layer.update(
|
||||
{f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
|
||||
)
|
||||
visual_encoder_layer.update(
|
||||
{f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
|
||||
)
|
||||
visual_encoder_layer.update(
|
||||
{f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
|
||||
)
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
|
||||
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
|
||||
|
||||
return visual_encoder_layer
|
||||
|
||||
|
||||
def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
|
||||
visual_encoder = {}
|
||||
|
||||
visual_encoder.update(
|
||||
{
|
||||
f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0)
|
||||
}
|
||||
)
|
||||
visual_encoder.update(
|
||||
{
|
||||
f"{diffuser_prefix}.embeddings.position_embedding": model[
|
||||
f"{original_prefix}.positional_embedding"
|
||||
].unsqueeze(0)
|
||||
}
|
||||
)
|
||||
visual_encoder.update(
|
||||
{f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
|
||||
)
|
||||
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
|
||||
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
|
||||
|
||||
for i in range(blip2config.vision_config.num_hidden_layers):
|
||||
visual_encoder.update(
|
||||
visual_encoder_layer_from_original_checkpoint(
|
||||
model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
|
||||
)
|
||||
)
|
||||
|
||||
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
|
||||
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
|
||||
|
||||
return visual_encoder
|
||||
|
||||
|
||||
def qformer_original_checkpoint_to_diffusers_checkpoint(model):
|
||||
qformer_checkpoint = {}
|
||||
qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
|
||||
qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
|
||||
qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
|
||||
qformer_checkpoint.update(
|
||||
encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
|
||||
)
|
||||
qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
|
||||
return qformer_checkpoint
|
||||
|
||||
|
||||
def get_qformer(model):
|
||||
print("loading qformer")
|
||||
|
||||
qformer = qformer_model_from_original_config()
|
||||
qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
|
||||
|
||||
load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
|
||||
|
||||
print("done loading qformer")
|
||||
return qformer
|
||||
|
||||
|
||||
def load_checkpoint_to_model(checkpoint, model):
|
||||
with tempfile.NamedTemporaryFile(delete=False) as file:
|
||||
torch.save(checkpoint, file.name)
|
||||
del checkpoint
|
||||
model.load_state_dict(torch.load(file.name), strict=False)
|
||||
|
||||
os.remove(file.name)
|
||||
|
||||
|
||||
def save_blip_diffusion_model(model, args):
|
||||
qformer = get_qformer(model)
|
||||
qformer.eval()
|
||||
|
||||
text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
||||
vae.eval()
|
||||
text_encoder.eval()
|
||||
scheduler = PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
set_alpha_to_one=False,
|
||||
skip_prk_steps=True,
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
|
||||
image_processor = BlipImageProcessor()
|
||||
blip_diffusion = BlipDiffusionPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
qformer=qformer,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
blip_diffusion.save_pretrained(args.checkpoint_path)
|
||||
|
||||
|
||||
def main(args):
|
||||
model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
|
||||
save_blip_diffusion_model(model.state_dict(), args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -35,12 +35,6 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="The YAML config file corresponding to the original architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_files",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The YAML config file corresponding to the architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_in_channels",
|
||||
default=None,
|
||||
|
||||
@@ -128,7 +128,6 @@ _deps = [
|
||||
"torchvision",
|
||||
"transformers>=4.25.1",
|
||||
"urllib3<=2.0.0",
|
||||
"peft>=0.5.0"
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
@@ -257,7 +256,7 @@ setup(
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
include_package_data=True,
|
||||
python_requires=">=3.8.0",
|
||||
python_requires=">=3.7.0",
|
||||
install_requires=list(install_requires),
|
||||
extras_require=extras,
|
||||
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
|
||||
@@ -269,6 +268,7 @@ setup(
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
|
||||
@@ -197,8 +197,6 @@ else:
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
"AudioLDMPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"CLIPImageProjection",
|
||||
"CycleDiffusionPipeline",
|
||||
"IFImg2ImgPipeline",
|
||||
@@ -460,8 +458,6 @@ if TYPE_CHECKING:
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
AutoPipelineForText2Image,
|
||||
BlipDiffusionControlNetPipeline,
|
||||
BlipDiffusionPipeline,
|
||||
CLIPImageProjection,
|
||||
ConsistencyModelPipeline,
|
||||
DanceDiffusionPipeline,
|
||||
|
||||
@@ -41,5 +41,4 @@ deps = {
|
||||
"torchvision": "torchvision",
|
||||
"transformers": "transformers>=4.25.1",
|
||||
"urllib3": "urllib3<=2.0.0",
|
||||
"peft": "peft>=0.5.0",
|
||||
}
|
||||
|
||||
+117
-271
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
@@ -24,7 +23,6 @@ import requests
|
||||
import safetensors
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, model_info
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
@@ -32,20 +30,11 @@ from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HF_HUB_OFFLINE,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_rank_and_alpha_pattern,
|
||||
is_accelerate_available,
|
||||
is_omegaconf_available,
|
||||
is_peft_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
@@ -72,21 +61,6 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
||||
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
# available.
|
||||
# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
|
||||
_required_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
) > version.parse("0.5")
|
||||
_required_transformers_version = version.parse(
|
||||
version.parse(importlib.metadata.version("transformers")).base_version
|
||||
) > version.parse("4.33")
|
||||
|
||||
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
|
||||
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
|
||||
|
||||
|
||||
class PatchedLoraProjection(nn.Module):
|
||||
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
|
||||
super().__init__()
|
||||
@@ -1103,11 +1077,8 @@ class LoraLoaderMixin:
|
||||
text_encoder_name = TEXT_ENCODER_NAME
|
||||
unet_name = UNET_NAME
|
||||
num_fused_loras = 0
|
||||
use_peft_backend = USE_PEFT_BACKEND
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
||||
`self.text_encoder`.
|
||||
@@ -1151,7 +1122,6 @@ class LoraLoaderMixin:
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
_pipeline=self,
|
||||
adapter_name=adapter_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1508,7 +1478,6 @@ class LoraLoaderMixin:
|
||||
lora_scale=1.0,
|
||||
low_cpu_mem_usage=None,
|
||||
_pipeline=None,
|
||||
adapter_name=None,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -1531,9 +1500,6 @@ class LoraLoaderMixin:
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
@@ -1554,35 +1520,55 @@ class LoraLoaderMixin:
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
if cls.use_peft_backend:
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
|
||||
# Convert from the old naming convention to the new naming convention.
|
||||
#
|
||||
# Previously, the old LoRA layers were stored on the state dict at the
|
||||
# same level as the attention block i.e.
|
||||
# `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
|
||||
#
|
||||
# This is no actual module at that point, they were monkey patched on to the
|
||||
# existing module. We want to be able to load them via their actual state dict.
|
||||
# They're in `PatchedLoraProjection.lora_linear_layer` now.
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.q_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.k_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.v_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
|
||||
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
|
||||
else:
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.q_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.k_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.v_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.out_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
|
||||
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
|
||||
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
@@ -1592,90 +1578,56 @@ class LoraLoaderMixin:
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
if cls.use_peft_backend:
|
||||
from peft import LoraConfig
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
r, lora_alpha, rank_pattern, alpha_pattern, target_modules = get_rank_and_alpha_pattern(
|
||||
rank, network_alphas, text_encoder_lora_state_dict
|
||||
is_pipeline_offloaded = _pipeline is not None and any(
|
||||
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
|
||||
)
|
||||
if is_pipeline_offloaded and low_cpu_mem_usage:
|
||||
low_cpu_mem_usage = True
|
||||
logger.info(
|
||||
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
|
||||
)
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=lora_alpha,
|
||||
rank_pattern=rank_pattern,
|
||||
alpha_pattern=alpha_pattern,
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
else:
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
)
|
||||
|
||||
is_pipeline_offloaded = _pipeline is not None and any(
|
||||
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
|
||||
for c in _pipeline.components.values()
|
||||
)
|
||||
if is_pipeline_offloaded and low_cpu_mem_usage:
|
||||
low_cpu_mem_usage = True
|
||||
logger.info(
|
||||
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
load_state_dict_results = text_encoder.load_state_dict(
|
||||
text_encoder_lora_state_dict, strict=False
|
||||
)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(
|
||||
getattr(component, "_hf_hook"), AlignDevicesHook
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(
|
||||
getattr(component, "_hf_hook"), AlignDevicesHook
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
@@ -1693,20 +1645,10 @@ class LoraLoaderMixin:
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if self.use_peft_backend:
|
||||
remove_method = recurse_remove_peft_layers
|
||||
else:
|
||||
remove_method = self._remove_text_encoder_monkey_patch_classmethod
|
||||
|
||||
if hasattr(self, "text_encoder"):
|
||||
remove_method(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
remove_method(self.text_encoder_2)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
|
||||
@classmethod
|
||||
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_linear_layer = None
|
||||
@@ -1733,7 +1675,6 @@ class LoraLoaderMixin:
|
||||
r"""
|
||||
Monkey-patches the forward passes of attention modules of the text encoder.
|
||||
"""
|
||||
deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
|
||||
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
|
||||
@@ -1937,7 +1878,7 @@ class LoraLoaderMixin:
|
||||
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
|
||||
|
||||
# SDXL specificity.
|
||||
if "emb" in diffusers_name and "time" not in diffusers_name:
|
||||
if "emb" in diffusers_name:
|
||||
pattern = r"\.\d+(?=\D*$)"
|
||||
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
|
||||
if ".in." in diffusers_name:
|
||||
@@ -1949,13 +1890,6 @@ class LoraLoaderMixin:
|
||||
if "skip" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
|
||||
|
||||
# LyCORIS specificity.
|
||||
if "time" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
|
||||
if "conv.shortcut" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
|
||||
|
||||
# General coverage.
|
||||
if "transformer_blocks" in diffusers_name:
|
||||
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
||||
@@ -2108,38 +2042,24 @@ class LoraLoaderMixin:
|
||||
if fuse_unet:
|
||||
self.unet.fuse_lora(lora_scale)
|
||||
|
||||
if self.use_peft_backend:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
def fuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale)
|
||||
attn_module.k_proj._fuse_lora(lora_scale)
|
||||
attn_module.v_proj._fuse_lora(lora_scale)
|
||||
attn_module.out_proj._fuse_lora(lora_scale)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
module.merge()
|
||||
|
||||
else:
|
||||
deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale)
|
||||
attn_module.k_proj._fuse_lora(lora_scale)
|
||||
attn_module.v_proj._fuse_lora(lora_scale)
|
||||
attn_module.out_proj._fuse_lora(lora_scale)
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale)
|
||||
mlp_module.fc2._fuse_lora(lora_scale)
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale)
|
||||
mlp_module.fc2._fuse_lora(lora_scale)
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale)
|
||||
fuse_text_encoder_lora(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
|
||||
fuse_text_encoder_lora(self.text_encoder_2)
|
||||
|
||||
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
|
||||
r"""
|
||||
@@ -2161,29 +2081,18 @@ class LoraLoaderMixin:
|
||||
if unfuse_unet:
|
||||
self.unet.unfuse_lora()
|
||||
|
||||
if self.use_peft_backend:
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
else:
|
||||
deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
|
||||
if unfuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -2193,65 +2102,6 @@ class LoraLoaderMixin:
|
||||
|
||||
self.num_fused_loras -= 1
|
||||
|
||||
def set_adapter(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
unet_weights: List[float] = None,
|
||||
te_weights: List[float] = None,
|
||||
te2_weights: List[float] = None,
|
||||
):
|
||||
if not self.use_peft_backend:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
if weights is None:
|
||||
weights = [1.0] * len(adapter_names)
|
||||
elif isinstance(weights, float):
|
||||
weights = [weights]
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
|
||||
# To Do
|
||||
# Handle the UNET
|
||||
|
||||
# Handle the Text Encoder
|
||||
te_weights = process_weights(adapter_names, te_weights)
|
||||
if hasattr(self, "text_encoder"):
|
||||
set_weights_and_activate_adapters(self.text_encoder, adapter_names, te_weights)
|
||||
te2_weights = process_weights(adapter_names, te2_weights)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
set_weights_and_activate_adapters(self.text_encoder_2, adapter_names, te2_weights)
|
||||
|
||||
def disable_lora(self):
|
||||
if not self.use_peft_backend:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
# To Do
|
||||
# Disbale unet adapters
|
||||
|
||||
# Disbale text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
set_adapter_layers(self.text_encoder, enabled=False)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
set_adapter_layers(self.text_encoder_2, enabled=False)
|
||||
|
||||
def enable_lora(self):
|
||||
if not self.use_peft_backend:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
# To Do
|
||||
# Enable unet adapters
|
||||
|
||||
# Enable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
set_adapter_layers(self.text_encoder, enabled=True)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
set_adapter_layers(self.text_encoder_2, enabled=True)
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
@@ -2953,9 +2803,5 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
)
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if self.use_peft_backend:
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
else:
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
|
||||
@@ -19,7 +19,6 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from .activations import get_activation
|
||||
from .lora import LoRACompatibleLinear
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -167,7 +166,7 @@ class TimestepEmbedding(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
||||
@@ -180,7 +179,7 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
|
||||
@@ -19,27 +19,24 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
from ..utils import logging, scale_lora_layers
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
|
||||
if use_peft_backend:
|
||||
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
|
||||
else:
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_scale = lora_scale
|
||||
attn_module.k_proj.lora_scale = lora_scale
|
||||
attn_module.v_proj.lora_scale = lora_scale
|
||||
attn_module.out_proj.lora_scale = lora_scale
|
||||
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_scale = lora_scale
|
||||
attn_module.k_proj.lora_scale = lora_scale
|
||||
attn_module.v_proj.lora_scale = lora_scale
|
||||
attn_module.out_proj.lora_scale = lora_scale
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_scale = lora_scale
|
||||
mlp_module.fc2.lora_scale = lora_scale
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_scale = lora_scale
|
||||
mlp_module.fc2.lora_scale = lora_scale
|
||||
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
|
||||
+456
-460
@@ -1,460 +1,456 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
# These modules contain pipelines from multiple libraries/frameworks
|
||||
_dummy_objects = {}
|
||||
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_pt_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
|
||||
else:
|
||||
_import_structure["auto_pipeline"] = [
|
||||
"AutoPipelineForImage2Image",
|
||||
"AutoPipelineForInpainting",
|
||||
"AutoPipelineForText2Image",
|
||||
]
|
||||
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
|
||||
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
|
||||
_import_structure["ddim"] = ["DDIMPipeline"]
|
||||
_import_structure["ddpm"] = ["DDPMPipeline"]
|
||||
_import_structure["dit"] = ["DiTPipeline"]
|
||||
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
|
||||
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
|
||||
_import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
|
||||
_import_structure["pndm"] = ["PNDMPipeline"]
|
||||
_import_structure["repaint"] = ["RePaintPipeline"]
|
||||
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
|
||||
_import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
|
||||
try:
|
||||
if not (is_torch_available() and is_librosa_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_torch_and_librosa_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
|
||||
else:
|
||||
_import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
"AudioLDM2Pipeline",
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
]
|
||||
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["deepfloyd_if"] = [
|
||||
"IFImg2ImgPipeline",
|
||||
"IFImg2ImgSuperResolutionPipeline",
|
||||
"IFInpaintingPipeline",
|
||||
"IFInpaintingSuperResolutionPipeline",
|
||||
"IFPipeline",
|
||||
"IFSuperResolutionPipeline",
|
||||
]
|
||||
_import_structure["kandinsky"] = [
|
||||
"KandinskyCombinedPipeline",
|
||||
"KandinskyImg2ImgCombinedPipeline",
|
||||
"KandinskyImg2ImgPipeline",
|
||||
"KandinskyInpaintCombinedPipeline",
|
||||
"KandinskyInpaintPipeline",
|
||||
"KandinskyPipeline",
|
||||
"KandinskyPriorPipeline",
|
||||
]
|
||||
_import_structure["kandinsky2_2"] = [
|
||||
"KandinskyV22CombinedPipeline",
|
||||
"KandinskyV22ControlnetImg2ImgPipeline",
|
||||
"KandinskyV22ControlnetPipeline",
|
||||
"KandinskyV22Img2ImgCombinedPipeline",
|
||||
"KandinskyV22Img2ImgPipeline",
|
||||
"KandinskyV22InpaintCombinedPipeline",
|
||||
"KandinskyV22InpaintPipeline",
|
||||
"KandinskyV22Pipeline",
|
||||
"KandinskyV22PriorEmb2EmbPipeline",
|
||||
"KandinskyV22PriorPipeline",
|
||||
]
|
||||
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
|
||||
_import_structure["musicldm"] = ["MusicLDMPipeline"]
|
||||
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
|
||||
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
|
||||
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"CLIPImageProjection",
|
||||
"CycleDiffusionPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
"StableDiffusionDiffEditPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENTextImagePipeline",
|
||||
"StableDiffusionImageVariationPipeline",
|
||||
"StableDiffusionImg2ImgPipeline",
|
||||
"StableDiffusionInpaintPipeline",
|
||||
"StableDiffusionInpaintPipelineLegacy",
|
||||
"StableDiffusionInstructPix2PixPipeline",
|
||||
"StableDiffusionLatentUpscalePipeline",
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionModelEditingPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionParadigmsPipeline",
|
||||
"StableDiffusionPipeline",
|
||||
"StableDiffusionPix2PixZeroPipeline",
|
||||
"StableDiffusionSAGPipeline",
|
||||
"StableDiffusionUpscalePipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
|
||||
_import_structure["stable_diffusion_xl"] = [
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
]
|
||||
_import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
|
||||
_import_structure["text_to_video_synthesis"] = [
|
||||
"TextToVideoSDPipeline",
|
||||
"TextToVideoZeroPipeline",
|
||||
"VideoToVideoSDPipeline",
|
||||
]
|
||||
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
|
||||
_import_structure["unidiffuser"] = [
|
||||
"ImageTextPipelineOutput",
|
||||
"UniDiffuserModel",
|
||||
"UniDiffuserPipeline",
|
||||
"UniDiffuserTextDecoder",
|
||||
]
|
||||
_import_structure["versatile_diffusion"] = [
|
||||
"VersatileDiffusionDualGuidedPipeline",
|
||||
"VersatileDiffusionImageVariationPipeline",
|
||||
"VersatileDiffusionPipeline",
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
]
|
||||
_import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
|
||||
_import_structure["wuerstchen"] = [
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_onnx_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
|
||||
else:
|
||||
_import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
|
||||
else:
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"OnnxStableDiffusionImg2ImgPipeline",
|
||||
"OnnxStableDiffusionInpaintPipeline",
|
||||
"OnnxStableDiffusionInpaintPipelineLegacy",
|
||||
"OnnxStableDiffusionPipeline",
|
||||
"OnnxStableDiffusionUpscalePipeline",
|
||||
"StableDiffusionOnnxPipeline",
|
||||
]
|
||||
)
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_flax_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"FlaxStableDiffusionImg2ImgPipeline",
|
||||
"FlaxStableDiffusionInpaintPipeline",
|
||||
"FlaxStableDiffusionPipeline",
|
||||
]
|
||||
)
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
|
||||
else:
|
||||
_import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
|
||||
from .consistency_models import ConsistencyModelPipeline
|
||||
from .dance_diffusion import DanceDiffusionPipeline
|
||||
from .ddim import DDIMPipeline
|
||||
from .ddpm import DDPMPipeline
|
||||
from .dit import DiTPipeline
|
||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
|
||||
from .pndm import PNDMPipeline
|
||||
from .repaint import RePaintPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .stochastic_karras_ve import KarrasVePipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_librosa_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_librosa_objects import *
|
||||
else:
|
||||
from .audio_diffusion import AudioDiffusionPipeline, Mel
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .controlnet import (
|
||||
BlipDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
IFInpaintingPipeline,
|
||||
IFInpaintingSuperResolutionPipeline,
|
||||
IFPipeline,
|
||||
IFSuperResolutionPipeline,
|
||||
)
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
KandinskyImg2ImgCombinedPipeline,
|
||||
KandinskyImg2ImgPipeline,
|
||||
KandinskyInpaintCombinedPipeline,
|
||||
KandinskyInpaintPipeline,
|
||||
KandinskyPipeline,
|
||||
KandinskyPriorPipeline,
|
||||
)
|
||||
from .kandinsky2_2 import (
|
||||
KandinskyV22CombinedPipeline,
|
||||
KandinskyV22ControlnetImg2ImgPipeline,
|
||||
KandinskyV22ControlnetPipeline,
|
||||
KandinskyV22Img2ImgCombinedPipeline,
|
||||
KandinskyV22Img2ImgPipeline,
|
||||
KandinskyV22InpaintCombinedPipeline,
|
||||
KandinskyV22InpaintPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .musicldm import MusicLDMPipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_diffusion import (
|
||||
CLIPImageProjection,
|
||||
CycleDiffusionPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
|
||||
from .text_to_video_synthesis import (
|
||||
TextToVideoSDPipeline,
|
||||
TextToVideoZeroPipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
)
|
||||
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
||||
from .unidiffuser import (
|
||||
ImageTextPipelineOutput,
|
||||
UniDiffuserModel,
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from .versatile_diffusion import (
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
from .wuerstchen import (
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_onnx_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_onnx_objects import *
|
||||
else:
|
||||
from .stable_diffusion import (
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionInpaintPipelineLegacy,
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .stable_diffusion import StableDiffusionKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_and_transformers_objects import *
|
||||
else:
|
||||
from .controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
from .stable_diffusion import (
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
# These modules contain pipelines from multiple libraries/frameworks
|
||||
_dummy_objects = {}
|
||||
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_pt_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
|
||||
else:
|
||||
_import_structure["auto_pipeline"] = [
|
||||
"AutoPipelineForImage2Image",
|
||||
"AutoPipelineForInpainting",
|
||||
"AutoPipelineForText2Image",
|
||||
]
|
||||
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
|
||||
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
|
||||
_import_structure["ddim"] = ["DDIMPipeline"]
|
||||
_import_structure["ddpm"] = ["DDPMPipeline"]
|
||||
_import_structure["dit"] = ["DiTPipeline"]
|
||||
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
|
||||
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
|
||||
_import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
|
||||
_import_structure["pndm"] = ["PNDMPipeline"]
|
||||
_import_structure["repaint"] = ["RePaintPipeline"]
|
||||
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
|
||||
_import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
|
||||
try:
|
||||
if not (is_torch_available() and is_librosa_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_torch_and_librosa_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
|
||||
else:
|
||||
_import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
"AudioLDM2Pipeline",
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["deepfloyd_if"] = [
|
||||
"IFImg2ImgPipeline",
|
||||
"IFImg2ImgSuperResolutionPipeline",
|
||||
"IFInpaintingPipeline",
|
||||
"IFInpaintingSuperResolutionPipeline",
|
||||
"IFPipeline",
|
||||
"IFSuperResolutionPipeline",
|
||||
]
|
||||
_import_structure["kandinsky"] = [
|
||||
"KandinskyCombinedPipeline",
|
||||
"KandinskyImg2ImgCombinedPipeline",
|
||||
"KandinskyImg2ImgPipeline",
|
||||
"KandinskyInpaintCombinedPipeline",
|
||||
"KandinskyInpaintPipeline",
|
||||
"KandinskyPipeline",
|
||||
"KandinskyPriorPipeline",
|
||||
]
|
||||
_import_structure["kandinsky2_2"] = [
|
||||
"KandinskyV22CombinedPipeline",
|
||||
"KandinskyV22ControlnetImg2ImgPipeline",
|
||||
"KandinskyV22ControlnetPipeline",
|
||||
"KandinskyV22Img2ImgCombinedPipeline",
|
||||
"KandinskyV22Img2ImgPipeline",
|
||||
"KandinskyV22InpaintCombinedPipeline",
|
||||
"KandinskyV22InpaintPipeline",
|
||||
"KandinskyV22Pipeline",
|
||||
"KandinskyV22PriorEmb2EmbPipeline",
|
||||
"KandinskyV22PriorPipeline",
|
||||
]
|
||||
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
|
||||
_import_structure["musicldm"] = ["MusicLDMPipeline"]
|
||||
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
|
||||
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
|
||||
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"CLIPImageProjection",
|
||||
"CycleDiffusionPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
"StableDiffusionDiffEditPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENTextImagePipeline",
|
||||
"StableDiffusionImageVariationPipeline",
|
||||
"StableDiffusionImg2ImgPipeline",
|
||||
"StableDiffusionInpaintPipeline",
|
||||
"StableDiffusionInpaintPipelineLegacy",
|
||||
"StableDiffusionInstructPix2PixPipeline",
|
||||
"StableDiffusionLatentUpscalePipeline",
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionModelEditingPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionParadigmsPipeline",
|
||||
"StableDiffusionPipeline",
|
||||
"StableDiffusionPix2PixZeroPipeline",
|
||||
"StableDiffusionSAGPipeline",
|
||||
"StableDiffusionUpscalePipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
|
||||
_import_structure["stable_diffusion_xl"] = [
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
]
|
||||
_import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
|
||||
_import_structure["text_to_video_synthesis"] = [
|
||||
"TextToVideoSDPipeline",
|
||||
"TextToVideoZeroPipeline",
|
||||
"VideoToVideoSDPipeline",
|
||||
]
|
||||
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
|
||||
_import_structure["unidiffuser"] = [
|
||||
"ImageTextPipelineOutput",
|
||||
"UniDiffuserModel",
|
||||
"UniDiffuserPipeline",
|
||||
"UniDiffuserTextDecoder",
|
||||
]
|
||||
_import_structure["versatile_diffusion"] = [
|
||||
"VersatileDiffusionDualGuidedPipeline",
|
||||
"VersatileDiffusionImageVariationPipeline",
|
||||
"VersatileDiffusionPipeline",
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
]
|
||||
_import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
|
||||
_import_structure["wuerstchen"] = [
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_onnx_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
|
||||
else:
|
||||
_import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
|
||||
else:
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"OnnxStableDiffusionImg2ImgPipeline",
|
||||
"OnnxStableDiffusionInpaintPipeline",
|
||||
"OnnxStableDiffusionInpaintPipelineLegacy",
|
||||
"OnnxStableDiffusionPipeline",
|
||||
"OnnxStableDiffusionUpscalePipeline",
|
||||
"StableDiffusionOnnxPipeline",
|
||||
]
|
||||
)
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_flax_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"FlaxStableDiffusionImg2ImgPipeline",
|
||||
"FlaxStableDiffusionInpaintPipeline",
|
||||
"FlaxStableDiffusionPipeline",
|
||||
]
|
||||
)
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
|
||||
else:
|
||||
_import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
|
||||
from .consistency_models import ConsistencyModelPipeline
|
||||
from .dance_diffusion import DanceDiffusionPipeline
|
||||
from .ddim import DDIMPipeline
|
||||
from .ddpm import DDPMPipeline
|
||||
from .dit import DiTPipeline
|
||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
|
||||
from .pndm import PNDMPipeline
|
||||
from .repaint import RePaintPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .stochastic_karras_ve import KarrasVePipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_librosa_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_librosa_objects import *
|
||||
else:
|
||||
from .audio_diffusion import AudioDiffusionPipeline, Mel
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
IFInpaintingPipeline,
|
||||
IFInpaintingSuperResolutionPipeline,
|
||||
IFPipeline,
|
||||
IFSuperResolutionPipeline,
|
||||
)
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
KandinskyImg2ImgCombinedPipeline,
|
||||
KandinskyImg2ImgPipeline,
|
||||
KandinskyInpaintCombinedPipeline,
|
||||
KandinskyInpaintPipeline,
|
||||
KandinskyPipeline,
|
||||
KandinskyPriorPipeline,
|
||||
)
|
||||
from .kandinsky2_2 import (
|
||||
KandinskyV22CombinedPipeline,
|
||||
KandinskyV22ControlnetImg2ImgPipeline,
|
||||
KandinskyV22ControlnetPipeline,
|
||||
KandinskyV22Img2ImgCombinedPipeline,
|
||||
KandinskyV22Img2ImgPipeline,
|
||||
KandinskyV22InpaintCombinedPipeline,
|
||||
KandinskyV22InpaintPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .musicldm import MusicLDMPipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_diffusion import (
|
||||
CLIPImageProjection,
|
||||
CycleDiffusionPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
|
||||
from .text_to_video_synthesis import (
|
||||
TextToVideoSDPipeline,
|
||||
TextToVideoZeroPipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
)
|
||||
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
||||
from .unidiffuser import (
|
||||
ImageTextPipelineOutput,
|
||||
UniDiffuserModel,
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from .versatile_diffusion import (
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
from .wuerstchen import (
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_onnx_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_onnx_objects import *
|
||||
else:
|
||||
from .stable_diffusion import (
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionInpaintPipelineLegacy,
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .stable_diffusion import StableDiffusionKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_and_transformers_objects import *
|
||||
else:
|
||||
from .controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
from .stable_diffusion import (
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -303,7 +303,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -301,7 +301,7 @@ class AltDiffusionImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
|
||||
else:
|
||||
from .blip_image_processing import BlipImageProcessor
|
||||
from .modeling_blip2 import Blip2QFormerModel
|
||||
from .modeling_ctx_clip import ContextCLIPTextModel
|
||||
from .pipeline_blip_diffusion import BlipDiffusionPipeline
|
||||
@@ -1,318 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for BLIP."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
|
||||
from transformers.image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from transformers.utils import TensorType, is_vision_available, logging
|
||||
|
||||
from diffusers.utils import numpy_to_pil
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop
|
||||
# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor
|
||||
class BlipImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a BLIP image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
||||
`do_resize` parameter in the `preprocess` method.
|
||||
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
|
||||
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||
method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
|
||||
overridden by the `resample` parameter in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
|
||||
`do_rescale` parameter in the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
|
||||
overridden by the `rescale_factor` parameter in the `preprocess` method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
||||
method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
|
||||
overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
do_center_crop: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 224, "width": 224}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.do_center_crop = do_center_crop
|
||||
|
||||
# Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image to `(size["height"], size["width"])`.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
do_center_crop: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Controls the size of the image after `resize`. The shortest edge of the image is resized to
|
||||
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
|
||||
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
|
||||
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image values between [0 - 1].
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to normalize the image by if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
images = make_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
if do_resize and size is None or resample is None:
|
||||
raise ValueError("Size and resample must be specified if do_resize is True.")
|
||||
|
||||
if do_rescale and rescale_factor is None:
|
||||
raise ValueError("Rescale factor must be specified if do_rescale is True.")
|
||||
|
||||
if do_normalize and (image_mean is None or image_std is None):
|
||||
raise ValueError("Image mean and std must be specified if do_normalize is True.")
|
||||
|
||||
# PIL RGBA images are converted to RGB
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
if do_normalize:
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
||||
return encoded_outputs
|
||||
|
||||
# Follows diffusers.VaeImageProcessor.postprocess
|
||||
def postprocess(self, sample: torch.FloatTensor, output_type: str = "pil"):
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(
|
||||
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
|
||||
)
|
||||
|
||||
# Equivalent to diffusers.VaeImageProcessor.denormalize
|
||||
sample = (sample / 2 + 0.5).clamp(0, 1)
|
||||
if output_type == "pt":
|
||||
return sample
|
||||
|
||||
# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
|
||||
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
|
||||
if output_type == "np":
|
||||
return sample
|
||||
# Output_type must be 'pil'
|
||||
sample = numpy_to_pil(sample)
|
||||
return sample
|
||||
@@ -1,642 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from transformers import BertTokenizer
|
||||
from transformers.activations import QuickGELUActivation as QuickGELU
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
)
|
||||
from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig
|
||||
from transformers.models.blip_2.modeling_blip_2 import (
|
||||
Blip2Encoder,
|
||||
Blip2PreTrainedModel,
|
||||
Blip2QFormerAttention,
|
||||
Blip2QFormerIntermediate,
|
||||
Blip2QFormerOutput,
|
||||
)
|
||||
from transformers.pytorch_utils import apply_chunking_to_forward
|
||||
from transformers.utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py.
|
||||
# But it doesn't support getting multimodal embeddings. So, this module can be
|
||||
# replaced with a future `transformers` version supports that.
|
||||
class Blip2TextEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word and position embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
query_embeds=None,
|
||||
past_key_values_length=0,
|
||||
):
|
||||
if input_ids is not None:
|
||||
seq_length = input_ids.size()[1]
|
||||
else:
|
||||
seq_length = 0
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
|
||||
|
||||
if input_ids is not None:
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.position_embedding_type == "absolute":
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings = embeddings + position_embeddings
|
||||
|
||||
if query_embeds is not None:
|
||||
batch_size = embeddings.shape[0]
|
||||
# repeat the query embeddings for batch size
|
||||
query_embeds = query_embeds.repeat(batch_size, 1, 1)
|
||||
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
||||
else:
|
||||
embeddings = query_embeds
|
||||
embeddings = embeddings.to(query_embeds.dtype)
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
|
||||
class Blip2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Blip2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
|
||||
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
|
||||
return embeddings
|
||||
|
||||
|
||||
# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings
|
||||
class Blip2QFormerEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList(
|
||||
[Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
query_length=0,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
layer_module = self.layer[i]
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, past_key_value, output_attentions, query_length)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
query_length,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if layer_module.has_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
# The layers making up the Qformer encoder
|
||||
class Blip2QFormerLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = Blip2QFormerAttention(config)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
if layer_idx % config.cross_attention_frequency == 0:
|
||||
self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True)
|
||||
self.has_cross_attention = True
|
||||
else:
|
||||
self.has_cross_attention = False
|
||||
|
||||
self.intermediate = Blip2QFormerIntermediate(config)
|
||||
self.intermediate_query = Blip2QFormerIntermediate(config)
|
||||
self.output_query = Blip2QFormerOutput(config)
|
||||
self.output = Blip2QFormerOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
query_length=0,
|
||||
):
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
|
||||
if query_length > 0:
|
||||
query_attention_output = attention_output[:, :query_length, :]
|
||||
|
||||
if self.has_cross_attention:
|
||||
if encoder_hidden_states is None:
|
||||
raise ValueError("encoder_hidden_states must be given for cross-attention layers")
|
||||
cross_attention_outputs = self.crossattention(
|
||||
query_attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
query_attention_output = cross_attention_outputs[0]
|
||||
# add cross attentions if we output attention weights
|
||||
outputs = outputs + cross_attention_outputs[1:-1]
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk_query,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
query_attention_output,
|
||||
)
|
||||
|
||||
if attention_output.shape[1] > query_length:
|
||||
layer_output_text = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output[:, query_length:, :],
|
||||
)
|
||||
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
||||
else:
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output,
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
def feed_forward_chunk_query(self, attention_output):
|
||||
intermediate_output = self.intermediate_query(attention_output)
|
||||
layer_output = self.output_query(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder
|
||||
class ProjLayer(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
|
||||
super().__init__()
|
||||
|
||||
# Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm
|
||||
self.dense1 = nn.Linear(in_dim, hidden_dim)
|
||||
self.act_fn = QuickGELU()
|
||||
self.dense2 = nn.Linear(hidden_dim, out_dim)
|
||||
self.dropout = nn.Dropout(drop_p)
|
||||
|
||||
self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
x_in = x
|
||||
|
||||
x = self.LayerNorm(x)
|
||||
x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
|
||||
class Blip2VisionModel(Blip2PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
config_class = Blip2VisionConfig
|
||||
|
||||
def __init__(self, config: Blip2VisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = Blip2VisionEmbeddings(config)
|
||||
self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.encoder = Blip2Encoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layernorm(hidden_states)
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
|
||||
# Qformer model, used to get multimodal embeddings from the text and image inputs
|
||||
class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
"""
|
||||
Querying Transformer (Q-Former), used in BLIP-2.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.embeddings = Blip2TextEmbeddings(config.qformer_config)
|
||||
self.visual_encoder = Blip2VisionModel(config.vision_config)
|
||||
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||
if not hasattr(config, "tokenizer") or config.tokenizer is None:
|
||||
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
|
||||
else:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right")
|
||||
self.tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
self.proj_layer = ProjLayer(
|
||||
in_dim=config.qformer_config.hidden_size,
|
||||
out_dim=config.qformer_config.hidden_size,
|
||||
hidden_dim=config.qformer_config.hidden_size * 4,
|
||||
drop_p=0.1,
|
||||
eps=1e-12,
|
||||
)
|
||||
|
||||
self.encoder = Blip2QFormerEncoder(config.qformer_config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def get_extended_attention_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_shape: Tuple[int],
|
||||
device: torch.device,
|
||||
has_query: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||
|
||||
Arguments:
|
||||
attention_mask (`torch.Tensor`):
|
||||
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||
input_shape (`Tuple[int]`):
|
||||
The shape of the input to the model.
|
||||
device (`torch.device`):
|
||||
The device of the input to the model.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
|
||||
"""
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
if attention_mask.dim() == 3:
|
||||
extended_attention_mask = attention_mask[:, None, :, :]
|
||||
elif attention_mask.dim() == 2:
|
||||
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||
# - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||
input_shape, attention_mask.shape
|
||||
)
|
||||
)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
return extended_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_input=None,
|
||||
image_input=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
|
||||
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
|
||||
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
|
||||
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
|
||||
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
|
||||
`(batch_size, sequence_length)`.
|
||||
use_cache (`bool`, `optional`):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
|
||||
text = self.tokenizer(text_input, return_tensors="pt", padding=True)
|
||||
text = text.to(self.device)
|
||||
input_ids = text.input_ids
|
||||
batch_size = input_ids.shape[0]
|
||||
query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device)
|
||||
attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = (
|
||||
past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
|
||||
)
|
||||
|
||||
query_length = self.query_tokens.shape[1]
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
query_embeds=self.query_tokens,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# embedding_output = self.layernorm(query_embeds)
|
||||
# embedding_output = self.dropout(embedding_output)
|
||||
|
||||
input_shape = embedding_output.size()[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = embedding_output.device
|
||||
|
||||
image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state
|
||||
# image_embeds_frozen = torch.ones_like(image_embeds_frozen)
|
||||
encoder_hidden_states = image_embeds_frozen
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, list):
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||
else:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
|
||||
if isinstance(encoder_attention_mask, list):
|
||||
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
||||
elif encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
query_length=query_length,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
|
||||
if not return_dict:
|
||||
return self.proj_layer(sequence_output[:, :query_length, :])
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
@@ -1,212 +0,0 @@
|
||||
# Copyright 2023 Salesforce.com, inc.
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import CLIPPreTrainedModel
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
||||
from transformers.models.clip.configuration_clip import CLIPTextConfig
|
||||
from transformers.models.clip.modeling_clip import (
|
||||
CLIPEncoder,
|
||||
_expand_mask,
|
||||
)
|
||||
|
||||
|
||||
# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
|
||||
# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer
|
||||
# They pass through the clip model, along with the text embeddings, and interact with them using self attention
|
||||
class ContextCLIPTextModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPTextConfig
|
||||
|
||||
_no_split_modules = ["CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__(config)
|
||||
self.text_model = ContextCLIPTextTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_embeddings: torch.Tensor = None,
|
||||
ctx_begin_pos: list = None,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
return self.text_model(
|
||||
ctx_embeddings=ctx_embeddings,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
class ContextCLIPTextTransformer(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = ContextCLIPTextEmbeddings(config)
|
||||
self.encoder = CLIPEncoder(config)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_embeddings: torch.Tensor,
|
||||
ctx_begin_pos: list,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify either input_ids")
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
ctx_embeddings=ctx_embeddings,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
)
|
||||
|
||||
bsz, seq_len = input_shape
|
||||
if ctx_embeddings is not None:
|
||||
seq_len += ctx_embeddings.size(1)
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
|
||||
hidden_states.device
|
||||
)
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
|
||||
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0], device=input_ids.device),
|
||||
input_ids.to(torch.int).argmax(dim=-1),
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
|
||||
mask.fill_(torch.tensor(torch.finfo(dtype).min))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask
|
||||
|
||||
|
||||
class ContextCLIPTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_embeddings: torch.Tensor,
|
||||
ctx_begin_pos: list,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if ctx_embeddings is None:
|
||||
ctx_len = 0
|
||||
else:
|
||||
ctx_len = ctx_embeddings.shape[1]
|
||||
|
||||
seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
# for each input embeddings, add the ctx embeddings at the correct position
|
||||
input_embeds_ctx = []
|
||||
bsz = inputs_embeds.shape[0]
|
||||
|
||||
if ctx_embeddings is not None:
|
||||
for i in range(bsz):
|
||||
cbp = ctx_begin_pos[i]
|
||||
|
||||
prefix = inputs_embeds[i, :cbp]
|
||||
# remove the special token embedding
|
||||
suffix = inputs_embeds[i, cbp:]
|
||||
|
||||
input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0))
|
||||
|
||||
inputs_embeds = torch.stack(input_embeds_ctx, dim=0)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
return embeddings
|
||||
@@ -1,339 +0,0 @@
|
||||
# Copyright 2023 Salesforce.com, inc.
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import PNDMScheduler
|
||||
from ...utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from .blip_image_processing import BlipImageProcessor
|
||||
from .modeling_blip2 import Blip2QFormerModel
|
||||
from .modeling_ctx_clip import ContextCLIPTextModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers.pipelines import BlipDiffusionPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import torch
|
||||
|
||||
>>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
|
||||
... "Salesforce/blipdiffusion", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
|
||||
>>> cond_subject = "dog"
|
||||
>>> tgt_subject = "dog"
|
||||
>>> text_prompt_input = "swimming underwater"
|
||||
|
||||
>>> cond_image = load_image(
|
||||
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
|
||||
... )
|
||||
>>> guidance_scale = 7.5
|
||||
>>> num_inference_steps = 25
|
||||
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
|
||||
|
||||
|
||||
>>> output = blip_diffusion_pipe(
|
||||
... text_prompt_input,
|
||||
... cond_image,
|
||||
... cond_subject,
|
||||
... tgt_subject,
|
||||
... guidance_scale=guidance_scale,
|
||||
... num_inference_steps=num_inference_steps,
|
||||
... neg_prompt=negative_prompt,
|
||||
... height=512,
|
||||
... width=512,
|
||||
... ).images
|
||||
>>> output[0].save("image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class BlipDiffusionPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer ([`CLIPTokenizer`]):
|
||||
Tokenizer for the text encoder
|
||||
text_encoder ([`ContextCLIPTextModel`]):
|
||||
Text encoder to encode the text prompt
|
||||
vae ([`AutoencoderKL`]):
|
||||
VAE model to map the latents to the image
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
scheduler ([`PNDMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
qformer ([`Blip2QFormerModel`]):
|
||||
QFormer model to get multi-modal embeddings from the text and image.
|
||||
image_processor ([`BlipImageProcessor`]):
|
||||
Image Processor to preprocess and postprocess the image.
|
||||
ctx_begin_pos (int, `optional`, defaults to 2):
|
||||
Position of the context token in the text encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: ContextCLIPTextModel,
|
||||
vae: AutoencoderKL,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: PNDMScheduler,
|
||||
qformer: Blip2QFormerModel,
|
||||
image_processor: BlipImageProcessor,
|
||||
ctx_begin_pos: int = 2,
|
||||
mean: List[float] = None,
|
||||
std: List[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
qformer=qformer,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
|
||||
|
||||
def get_query_embeddings(self, input_image, src_subject):
|
||||
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
|
||||
|
||||
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
|
||||
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
|
||||
rv = []
|
||||
for prompt, tgt_subject in zip(prompts, tgt_subjects):
|
||||
prompt = f"a {tgt_subject} {prompt.strip()}"
|
||||
# a trick to amplify the prompt
|
||||
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
|
||||
|
||||
return rv
|
||||
|
||||
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def encode_prompt(self, query_embeds, prompt):
|
||||
# embeddings for prompt, with query_embeds as context
|
||||
max_len = self.text_encoder.text_model.config.max_position_embeddings
|
||||
max_len -= self.qformer.config.num_query_tokens
|
||||
|
||||
tokenized_prompt = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_len,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
batch_size = query_embeds.shape[0]
|
||||
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
|
||||
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=tokenized_prompt.input_ids,
|
||||
ctx_embeddings=query_embeds,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
)[0]
|
||||
|
||||
return text_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: List[str],
|
||||
reference_image: PIL.Image.Image,
|
||||
source_subject_category: List[str],
|
||||
target_subject_category: List[str],
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
neg_prompt: Optional[str] = "",
|
||||
prompt_strength: float = 1.0,
|
||||
prompt_reps: int = 20,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
reference_image (`PIL.Image.Image`):
|
||||
The reference image to condition the generation on.
|
||||
source_subject_category (`List[str]`):
|
||||
The source subject category.
|
||||
target_subject_category (`List[str]`):
|
||||
The target subject category.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by random sampling.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
neg_prompt (`str`, *optional*, defaults to ""):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_strength (`float`, *optional*, defaults to 1.0):
|
||||
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
|
||||
to amplify the prompt.
|
||||
prompt_reps (`int`, *optional*, defaults to 20):
|
||||
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
reference_image = self.image_processor.preprocess(
|
||||
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
|
||||
)["pixel_values"]
|
||||
reference_image = reference_image.to(self.device)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
if isinstance(source_subject_category, str):
|
||||
source_subject_category = [source_subject_category]
|
||||
if isinstance(target_subject_category, str):
|
||||
target_subject_category = [target_subject_category]
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = self._build_prompt(
|
||||
prompts=prompt,
|
||||
tgt_subjects=target_subject_category,
|
||||
prompt_strength=prompt_strength,
|
||||
prompt_reps=prompt_reps,
|
||||
)
|
||||
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
|
||||
text_embeddings = self.encode_prompt(query_embeds, prompt)
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
if do_classifier_free_guidance:
|
||||
max_length = self.text_encoder.text_model.config.max_position_embeddings
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
[neg_prompt] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.to(self.device),
|
||||
ctx_embeddings=None,
|
||||
)[0]
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
num_channels=self.unet.config.in_channels,
|
||||
height=height // scale_down_factor,
|
||||
width=width // scale_down_factor,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
dtype=self.unet.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
# set timesteps
|
||||
extra_set_kwargs = {}
|
||||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
down_block_additional_residuals=None,
|
||||
mid_block_additional_residual=None,
|
||||
)["sample"]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
)["prev_sample"]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -1,79 +1,77 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -291,7 +291,7 @@ class StableDiffusionControlNetPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -1,405 +0,0 @@
|
||||
# Copyright 2023 Salesforce.com, inc.
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import PNDMScheduler
|
||||
from ...utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
|
||||
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
||||
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from controlnet_aux import CannyDetector
|
||||
>>> import torch
|
||||
|
||||
>>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
|
||||
... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> style_subject = "flower"
|
||||
>>> tgt_subject = "teapot"
|
||||
>>> text_prompt = "on a marble table"
|
||||
|
||||
>>> cldm_cond_image = load_image(
|
||||
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
|
||||
... ).resize(512, 512)
|
||||
>>> canny = CannyDetector()
|
||||
>>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
|
||||
>>> style_image = load_image(
|
||||
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
|
||||
... )
|
||||
>>> guidance_scale = 7.5
|
||||
>>> num_inference_steps = 50
|
||||
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
|
||||
|
||||
|
||||
>>> output = blip_diffusion_pipe(
|
||||
... text_prompt,
|
||||
... style_image,
|
||||
... cldm_cond_image,
|
||||
... style_subject,
|
||||
... tgt_subject,
|
||||
... guidance_scale=guidance_scale,
|
||||
... num_inference_steps=num_inference_steps,
|
||||
... neg_prompt=negative_prompt,
|
||||
... height=512,
|
||||
... width=512,
|
||||
... ).images
|
||||
>>> output[0].save("image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class BlipDiffusionControlNetPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer ([`CLIPTokenizer`]):
|
||||
Tokenizer for the text encoder
|
||||
text_encoder ([`ContextCLIPTextModel`]):
|
||||
Text encoder to encode the text prompt
|
||||
vae ([`AutoencoderKL`]):
|
||||
VAE model to map the latents to the image
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
scheduler ([`PNDMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
qformer ([`Blip2QFormerModel`]):
|
||||
QFormer model to get multi-modal embeddings from the text and image.
|
||||
controlnet ([`ControlNetModel`]):
|
||||
ControlNet model to get the conditioning image embedding.
|
||||
image_processor ([`BlipImageProcessor`]):
|
||||
Image Processor to preprocess and postprocess the image.
|
||||
ctx_begin_pos (int, `optional`, defaults to 2):
|
||||
Position of the context token in the text encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: ContextCLIPTextModel,
|
||||
vae: AutoencoderKL,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: PNDMScheduler,
|
||||
qformer: Blip2QFormerModel,
|
||||
controlnet: ControlNetModel,
|
||||
image_processor: BlipImageProcessor,
|
||||
ctx_begin_pos: int = 2,
|
||||
mean: List[float] = None,
|
||||
std: List[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
qformer=qformer,
|
||||
controlnet=controlnet,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
|
||||
|
||||
def get_query_embeddings(self, input_image, src_subject):
|
||||
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
|
||||
|
||||
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
|
||||
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
|
||||
rv = []
|
||||
for prompt, tgt_subject in zip(prompts, tgt_subjects):
|
||||
prompt = f"a {tgt_subject} {prompt.strip()}"
|
||||
# a trick to amplify the prompt
|
||||
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
|
||||
|
||||
return rv
|
||||
|
||||
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def encode_prompt(self, query_embeds, prompt):
|
||||
# embeddings for prompt, with query_embeds as context
|
||||
max_len = self.text_encoder.text_model.config.max_position_embeddings
|
||||
max_len -= self.qformer.config.num_query_tokens
|
||||
|
||||
tokenized_prompt = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_len,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
batch_size = query_embeds.shape[0]
|
||||
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
|
||||
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=tokenized_prompt.input_ids,
|
||||
ctx_embeddings=query_embeds,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
)[0]
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
||||
def prepare_control_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
):
|
||||
image = self.image_processor.preprocess(
|
||||
image,
|
||||
size={"width": width, "height": height},
|
||||
do_rescale=True,
|
||||
do_center_crop=False,
|
||||
do_normalize=False,
|
||||
return_tensors="pt",
|
||||
)["pixel_values"].to(self.device)
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: List[str],
|
||||
reference_image: PIL.Image.Image,
|
||||
condtioning_image: PIL.Image.Image,
|
||||
source_subject_category: List[str],
|
||||
target_subject_category: List[str],
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
neg_prompt: Optional[str] = "",
|
||||
prompt_strength: float = 1.0,
|
||||
prompt_reps: int = 20,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
reference_image (`PIL.Image.Image`):
|
||||
The reference image to condition the generation on.
|
||||
condtioning_image (`PIL.Image.Image`):
|
||||
The conditioning canny edge image to condition the generation on.
|
||||
source_subject_category (`List[str]`):
|
||||
The source subject category.
|
||||
target_subject_category (`List[str]`):
|
||||
The target subject category.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by random sampling.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width of the generated image.
|
||||
seed (`int`, *optional*, defaults to 42):
|
||||
The seed to use for random generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
neg_prompt (`str`, *optional*, defaults to ""):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_strength (`float`, *optional*, defaults to 1.0):
|
||||
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
|
||||
to amplify the prompt.
|
||||
prompt_reps (`int`, *optional*, defaults to 20):
|
||||
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
reference_image = self.image_processor.preprocess(
|
||||
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
|
||||
)["pixel_values"]
|
||||
reference_image = reference_image.to(self.device)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
if isinstance(source_subject_category, str):
|
||||
source_subject_category = [source_subject_category]
|
||||
if isinstance(target_subject_category, str):
|
||||
target_subject_category = [target_subject_category]
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = self._build_prompt(
|
||||
prompts=prompt,
|
||||
tgt_subjects=target_subject_category,
|
||||
prompt_strength=prompt_strength,
|
||||
prompt_reps=prompt_reps,
|
||||
)
|
||||
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
|
||||
text_embeddings = self.encode_prompt(query_embeds, prompt)
|
||||
# 3. unconditional embedding
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
if do_classifier_free_guidance:
|
||||
max_length = self.text_encoder.text_model.config.max_position_embeddings
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
[neg_prompt] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.to(self.device),
|
||||
ctx_embeddings=None,
|
||||
)[0]
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
num_channels=self.unet.config.in_channels,
|
||||
height=height // scale_down_factor,
|
||||
width=width // scale_down_factor,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
dtype=self.unet.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
# set timesteps
|
||||
extra_set_kwargs = {}
|
||||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
|
||||
cond_image = self.prepare_control_image(
|
||||
image=condtioning_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=1,
|
||||
device=self.device,
|
||||
dtype=self.controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
controlnet_cond=cond_image,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
)["sample"]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
)["prev_sample"]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -315,7 +315,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -442,7 +442,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -315,7 +315,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -288,7 +288,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -326,7 +326,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -301,7 +301,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
+1
-1
@@ -332,7 +332,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -481,7 +481,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -278,7 +278,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
+1
-1
@@ -309,7 +309,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -302,7 +302,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -375,7 +375,7 @@ class StableDiffusionInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
+1
-1
@@ -297,7 +297,7 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -211,7 +211,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -609,9 +609,6 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
|
||||
sampler_kwargs["noise_sampler"] = noise_sampler
|
||||
|
||||
if "generator" in inspect.signature(self.sampler).parameters:
|
||||
sampler_kwargs["generator"] = generator
|
||||
|
||||
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
|
||||
|
||||
if not output_type == "latent":
|
||||
|
||||
@@ -272,7 +272,7 @@ class StableDiffusionLDM3DPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -244,7 +244,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -221,7 +221,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -256,7 +256,7 @@ class StableDiffusionParadigmsPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -446,7 +446,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -244,7 +244,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -240,7 +240,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -346,7 +346,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -296,7 +296,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -264,7 +264,7 @@ class StableDiffusionXLPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -271,7 +271,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -420,7 +420,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
+1
-1
@@ -272,7 +272,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
|
||||
@@ -296,7 +296,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -288,7 +288,7 @@ class StableDiffusionXLAdapterPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
@@ -787,16 +787,8 @@ class StableDiffusionXLAdapterPipeline(
|
||||
height, width = self._default_height_width(height, width, image)
|
||||
device = self._execution_device
|
||||
|
||||
if isinstance(self.adapter, MultiAdapter):
|
||||
adapter_input = []
|
||||
adapter_input = _preprocess_adapter_image(image, height, width).to(device)
|
||||
|
||||
for one_image in image:
|
||||
one_image = _preprocess_adapter_image(one_image, height, width)
|
||||
one_image = one_image.to(device=device, dtype=self.adapter.dtype)
|
||||
adapter_input.append(one_image)
|
||||
else:
|
||||
adapter_input = _preprocess_adapter_image(image, height, width)
|
||||
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
@@ -873,14 +865,10 @@ class StableDiffusionXLAdapterPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare added time ids & embeddings & adapter features
|
||||
if isinstance(self.adapter, MultiAdapter):
|
||||
adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v
|
||||
else:
|
||||
adapter_state = self.adapter(adapter_input)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v * adapter_conditioning_scale
|
||||
adapter_input = adapter_input.type(latents.dtype)
|
||||
adapter_state = self.adapter(adapter_input)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v * adapter_conditioning_scale
|
||||
if num_images_per_prompt > 1:
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
|
||||
|
||||
@@ -228,7 +228,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
+1
-1
@@ -290,7 +290,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -22,7 +22,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -187,14 +186,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -234,16 +225,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
@@ -253,9 +245,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -291,57 +280,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DEIS algorithm needs.
|
||||
@@ -358,26 +298,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -389,6 +316,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
if self.config.algorithm_type == "deis":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
return (sample - alpha_t * x0_pred) / sigma_t
|
||||
else:
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
@@ -396,9 +324,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def deis_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DEIS (equivalent to DDIM).
|
||||
@@ -417,33 +345,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "deis":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
@@ -454,9 +358,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_deis_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DEIS.
|
||||
@@ -464,6 +368,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -471,38 +379,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1]
|
||||
sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1]
|
||||
|
||||
rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
|
||||
|
||||
@@ -523,9 +403,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_deis_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DEIS.
|
||||
@@ -533,6 +413,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
|
||||
@@ -540,47 +424,15 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2]
|
||||
sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2]
|
||||
rho_t, rho_s0, rho_s1, rho_s2 = (
|
||||
sigma_t / alpha_t,
|
||||
sigma_s0 / alpha_s0,
|
||||
sigma_s1 / alpha_s1,
|
||||
sigma_s2 / alpha_s2,
|
||||
simga_s2 / alpha_s2,
|
||||
)
|
||||
|
||||
if self.config.algorithm_type == "deis":
|
||||
@@ -608,25 +460,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -659,34 +492,42 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
lower_order_final = (
|
||||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.deis_first_order_update(model_output, sample=sample)
|
||||
prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_deis_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
else:
|
||||
prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_deis_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -707,30 +548,28 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,7 +21,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
@@ -204,14 +203,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -251,19 +242,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
@@ -273,9 +264,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -335,12 +323,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -356,11 +338,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -377,6 +355,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -384,18 +364,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
||||
@@ -403,14 +371,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned", "learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -432,12 +398,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
epsilon = model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -446,8 +410,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
||||
@@ -457,10 +420,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPMSolver (equivalent to DDIM).
|
||||
@@ -468,6 +431,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -475,33 +442,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
@@ -526,10 +469,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DPMSolver.
|
||||
@@ -537,6 +480,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -544,43 +491,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
||||
@@ -649,9 +564,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DPMSolver.
|
||||
@@ -659,6 +574,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
|
||||
@@ -666,47 +585,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
||||
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
)
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m0
|
||||
@@ -731,25 +619,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -785,17 +654,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
lower_order_final = (
|
||||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
@@ -808,18 +682,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise = None
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
|
||||
prev_sample = self.dpm_solver_first_order_update(
|
||||
model_output, timestep, prev_timestep, sample, noise=noise
|
||||
)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
||||
)
|
||||
else:
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -840,30 +719,28 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,7 +21,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
@@ -204,16 +203,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -253,19 +244,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
timesteps = timesteps.copy().astype(np.int64)
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_max = (
|
||||
(1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
|
||||
) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
@@ -283,9 +266,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -345,13 +325,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -368,11 +341,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -389,6 +358,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -396,18 +367,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
||||
@@ -415,14 +374,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned", "learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -444,12 +401,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
epsilon = model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -458,22 +413,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
||||
|
||||
return epsilon
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPMSolver (equivalent to DDIM).
|
||||
@@ -481,6 +434,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -488,62 +445,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
elif self.config.algorithm_type == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
x_t = (
|
||||
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver":
|
||||
assert noise is not None
|
||||
x_t = (
|
||||
(alpha_t / alpha_s) * sample
|
||||
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
||||
elif "sde" in self.config.algorithm_type:
|
||||
raise NotImplementedError(
|
||||
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DPMSolver.
|
||||
@@ -551,6 +473,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -558,43 +484,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
||||
@@ -626,47 +520,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
||||
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
||||
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config.algorithm_type == "sde-dpmsolver":
|
||||
assert noise is not None
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
elif "sde" in self.config.algorithm_type:
|
||||
raise NotImplementedError(
|
||||
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DPMSolver.
|
||||
@@ -674,6 +540,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep (`int`):
|
||||
The current and latter discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
|
||||
@@ -681,47 +551,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
||||
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
)
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m0
|
||||
@@ -746,27 +585,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -786,8 +604,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
||||
|
||||
@@ -802,17 +618,24 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = (
|
||||
self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
)
|
||||
lower_order_final = (
|
||||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
@@ -825,18 +648,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise = None
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
|
||||
prev_sample = self.dpm_solver_first_order_update(
|
||||
model_output, timestep, prev_timestep, sample, noise=noise
|
||||
)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
||||
)
|
||||
else:
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -858,30 +686,28 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,7 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils import logging
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -197,7 +197,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.sample = None
|
||||
self.order_list = self.get_order_list(num_train_timesteps)
|
||||
self._step_index = None
|
||||
|
||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||
"""
|
||||
@@ -233,13 +232,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
orders = [1] * steps
|
||||
return orders
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -264,16 +256,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
self.model_outputs = [None] * self.config.solver_order
|
||||
@@ -287,9 +274,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.order_list = self.get_order_list(num_inference_steps)
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -349,13 +333,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -371,11 +348,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -392,6 +365,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
@@ -399,32 +374,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.prediction_type == "epsilon":
|
||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||
if self.config.variance_type in ["learned_range"]:
|
||||
model_output = model_output[:, :3]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -444,13 +405,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output = model_output[:, :3]
|
||||
return model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
@@ -462,9 +421,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPMSolver (equivalent to DDIM).
|
||||
@@ -483,31 +442,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
@@ -518,9 +455,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def singlestep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
|
||||
@@ -540,42 +477,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1]
|
||||
sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1]
|
||||
h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m1, (1.0 / r0) * (m0 - m1)
|
||||
@@ -612,9 +518,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def singlestep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
|
||||
@@ -634,47 +540,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
||||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
||||
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
)
|
||||
alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2]
|
||||
sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2]
|
||||
h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m2
|
||||
@@ -716,10 +591,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def singlestep_dpm_solver_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
order: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the singlestep DPMSolver.
|
||||
@@ -740,60 +615,19 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
||||
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 2:
|
||||
sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`sample` as a required keyward argument")
|
||||
if order is None:
|
||||
if len(args) > 3:
|
||||
order = args[3]
|
||||
else:
|
||||
raise ValueError(" missing `order` as a required keyward argument")
|
||||
if timestep_list is not None:
|
||||
deprecate(
|
||||
"timestep_list",
|
||||
"1.0.0",
|
||||
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
if order == 1:
|
||||
return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
|
||||
return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample)
|
||||
elif order == 2:
|
||||
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
|
||||
return self.singlestep_dpm_solver_second_order_update(
|
||||
model_output_list, timestep_list, prev_timestep, sample
|
||||
)
|
||||
elif order == 3:
|
||||
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
|
||||
return self.singlestep_dpm_solver_third_order_update(
|
||||
model_output_list, timestep_list, prev_timestep, sample
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Order must be 1, 2, 3, got {order}")
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -826,15 +660,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
|
||||
order = self.order_list[self.step_index]
|
||||
order = self.order_list[step_index]
|
||||
|
||||
# For img2img denoising might start with order>1 which is not possible
|
||||
# In this case make sure that the first two steps are both order=1
|
||||
@@ -845,10 +685,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if order == 1:
|
||||
self.sample = sample
|
||||
|
||||
prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep]
|
||||
prev_sample = self.singlestep_dpm_solver_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, self.sample, order
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
@@ -870,30 +710,28 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -89,9 +89,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
@@ -116,7 +113,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.012,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
@@ -247,15 +243,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
|
||||
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
|
||||
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
@@ -279,13 +269,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
|
||||
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
sigmas_interpol = sigmas_interpol.cpu()
|
||||
log_sigmas = self.log_sigmas.cpu()
|
||||
timesteps_interpol = np.array(
|
||||
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
|
||||
)
|
||||
|
||||
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
|
||||
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
|
||||
|
||||
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
|
||||
@@ -298,44 +282,29 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
def sigma_to_t(self, sigma):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
log_sigma = sigma.log()
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
low = self.log_sigmas[low_idx]
|
||||
high = self.log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
w = w.clamp(0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
t = t.view(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
@@ -88,9 +88,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
@@ -115,7 +112,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.012,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
@@ -247,14 +243,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
|
||||
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
|
||||
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
@@ -269,12 +260,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
# interpolate timesteps
|
||||
sigmas_interpol = sigmas_interpol.cpu()
|
||||
log_sigmas = self.log_sigmas.cpu()
|
||||
timesteps_interpol = np.array(
|
||||
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
|
||||
)
|
||||
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
|
||||
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
|
||||
|
||||
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
|
||||
@@ -287,6 +273,29 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = None
|
||||
|
||||
def sigma_to_t(self, sigma):
|
||||
# get log sigma
|
||||
log_sigma = sigma.log()
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = self.log_sigmas[low_idx]
|
||||
high = self.log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.view(sigma.shape)
|
||||
return t
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
@@ -309,44 +318,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
|
||||
@@ -22,16 +22,10 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -44,30 +38,19 @@ def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -198,14 +181,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.disable_corrector = disable_corrector
|
||||
self.solver_p = solver_p
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -245,16 +220,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
||||
# duplicates in timesteps.
|
||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
||||
timesteps = timesteps[np.sort(unique_indices)]
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
@@ -267,9 +243,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.solver_p:
|
||||
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -329,13 +302,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||||
sigma_t = sigma * alpha_t
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -351,11 +317,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
**kwargs,
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Convert the model output to the corresponding type the UniPC algorithm needs.
|
||||
@@ -372,28 +334,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyward argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
|
||||
if self.predict_x0:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -409,9 +357,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.config.prediction_type == "epsilon":
|
||||
return model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
@@ -423,10 +373,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_uni_p_bh_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
*args,
|
||||
sample: torch.FloatTensor = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
order: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
||||
@@ -445,26 +394,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError(" missing `sample` as a required keyward argument")
|
||||
if order is None:
|
||||
if len(args) > 2:
|
||||
order = args[2]
|
||||
else:
|
||||
raise ValueError(" missing `order` as a required keyward argument")
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
timestep_list = self.timestep_list
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0 = self.timestep_list[-1]
|
||||
s0, t = self.timestep_list[-1], prev_timestep
|
||||
m0 = model_output_list[-1]
|
||||
x = sample
|
||||
|
||||
@@ -472,12 +405,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
||||
return x_t
|
||||
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = sample.device
|
||||
@@ -485,10 +415,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = self.step_index - i
|
||||
si = timestep_list[-(i + 1)]
|
||||
mi = model_output_list[-(i + 1)]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
||||
lambda_si = self.lambda_t[si]
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
@@ -552,11 +481,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_uni_c_bh_update(
|
||||
self,
|
||||
this_model_output: torch.FloatTensor,
|
||||
*args,
|
||||
last_sample: torch.FloatTensor = None,
|
||||
this_sample: torch.FloatTensor = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
this_timestep: int,
|
||||
last_sample: torch.FloatTensor,
|
||||
this_sample: torch.FloatTensor,
|
||||
order: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the UniC (B(h) version).
|
||||
@@ -577,42 +505,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
The corrected sample tensor at the current timestep.
|
||||
"""
|
||||
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
|
||||
if last_sample is None:
|
||||
if len(args) > 1:
|
||||
last_sample = args[1]
|
||||
else:
|
||||
raise ValueError(" missing`last_sample` as a required keyward argument")
|
||||
if this_sample is None:
|
||||
if len(args) > 2:
|
||||
this_sample = args[2]
|
||||
else:
|
||||
raise ValueError(" missing`this_sample` as a required keyward argument")
|
||||
if order is None:
|
||||
if len(args) > 3:
|
||||
order = args[3]
|
||||
else:
|
||||
raise ValueError(" missing`order` as a required keyward argument")
|
||||
if this_timestep is not None:
|
||||
deprecate(
|
||||
"this_timestep",
|
||||
"1.0.0",
|
||||
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
timestep_list = self.timestep_list
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0, t = timestep_list[-1], this_timestep
|
||||
m0 = model_output_list[-1]
|
||||
x = last_sample
|
||||
x_t = this_sample
|
||||
model_t = this_model_output
|
||||
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = this_sample.device
|
||||
@@ -620,10 +524,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = self.step_index - (i + 1)
|
||||
si = timestep_list[-(i + 1)]
|
||||
mi = model_output_list[-(i + 1)]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
||||
lambda_si = self.lambda_t[si]
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
@@ -686,25 +589,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -732,27 +616,37 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
|
||||
use_corrector = (
|
||||
self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
|
||||
step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
|
||||
)
|
||||
|
||||
model_output_convert = self.convert_model_output(model_output, sample=sample)
|
||||
model_output_convert = self.convert_model_output(model_output, timestep, sample)
|
||||
if use_corrector:
|
||||
sample = self.multistep_uni_c_bh_update(
|
||||
this_model_output=model_output_convert,
|
||||
this_timestep=timestep,
|
||||
last_sample=self.last_sample,
|
||||
this_sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
|
||||
# now prepare to run the predictor
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.timestep_list[i] = self.timestep_list[i + 1]
|
||||
@@ -761,7 +655,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timestep_list[-1] = timestep
|
||||
|
||||
if self.config.lower_order_final:
|
||||
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
|
||||
this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
|
||||
else:
|
||||
this_order = self.config.solver_order
|
||||
|
||||
@@ -771,6 +665,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.last_sample = sample
|
||||
prev_sample = self.multistep_uni_p_bh_update(
|
||||
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
||||
prev_timestep=prev_timestep,
|
||||
sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
@@ -778,9 +673,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
@@ -801,30 +693,28 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -67,7 +67,6 @@ from .import_utils import (
|
||||
is_note_seq_available,
|
||||
is_omegaconf_available,
|
||||
is_onnx_available,
|
||||
is_peft_available,
|
||||
is_scipy_available,
|
||||
is_tensorboard_available,
|
||||
is_torch_available,
|
||||
@@ -83,16 +82,7 @@ from .import_utils import (
|
||||
from .loading_utils import load_image
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import (
|
||||
get_adapter_name,
|
||||
get_rank_and_alpha_pattern,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
|
||||
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -315,36 +315,6 @@ class AutoPipelineForText2Image(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlipDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlipDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CLIPImageProjection(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -267,14 +267,6 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_invisible_watermark_available = False
|
||||
|
||||
|
||||
_peft_available = importlib.util.find_spec("peft") is not None
|
||||
try:
|
||||
_peft_version = importlib_metadata.version("peft")
|
||||
logger.debug(f"Successfully imported accelerate version {_peft_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_peft_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
@@ -359,10 +351,6 @@ def is_invisible_watermark_available():
|
||||
return _invisible_watermark_available
|
||||
|
||||
|
||||
def is_peft_available():
|
||||
return _peft_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
PEFT utilities: Utilities related to peft library
|
||||
"""
|
||||
from .import_utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def recurse_remove_peft_layers(model):
|
||||
r"""
|
||||
Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
|
||||
"""
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
for name, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
## compound module, go inside it
|
||||
recurse_remove_peft_layers(module)
|
||||
|
||||
module_replaced = False
|
||||
|
||||
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
|
||||
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
|
||||
module.weight.device
|
||||
)
|
||||
new_module.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_module.bias = module.bias
|
||||
|
||||
module_replaced = True
|
||||
elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
|
||||
new_module = torch.nn.Conv2d(
|
||||
module.in_channels,
|
||||
module.out_channels,
|
||||
module.kernel_size,
|
||||
module.stride,
|
||||
module.padding,
|
||||
module.dilation,
|
||||
module.groups,
|
||||
module.bias,
|
||||
).to(module.weight.device)
|
||||
|
||||
new_module.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_module.bias = module.bias
|
||||
|
||||
module_replaced = True
|
||||
|
||||
if module_replaced:
|
||||
setattr(model, name, new_module)
|
||||
del module
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def scale_lora_layers(model, lora_weightage):
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.scale_layer(lora_weightage)
|
||||
|
||||
|
||||
def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict):
|
||||
rank_pattern = None
|
||||
alpha_pattern = None
|
||||
r = lora_alpha = list(rank_dict.values())[0]
|
||||
if len(set(rank_dict.values())) > 1:
|
||||
# get the rank occuring the most number of times
|
||||
r = max(set(rank_dict.values()), key=list(rank_dict.values()).count)
|
||||
|
||||
# for modules with rank different from the most occuring rank, add it to the `rank_pattern`
|
||||
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
|
||||
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
|
||||
|
||||
if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
|
||||
# get the alpha occuring the most number of times
|
||||
lora_alpha = max(set(network_alpha_dict.values()), key=list(network_alpha_dict.values()).count)
|
||||
|
||||
# for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
|
||||
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
|
||||
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
|
||||
|
||||
# layer names without the Diffusers specific
|
||||
target_modules = {name.split(".lora")[0] for name in peft_state_dict.keys()}
|
||||
|
||||
return r, lora_alpha, rank_pattern, alpha_pattern, target_modules
|
||||
|
||||
|
||||
def get_adapter_name(model):
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return f"default_{len(module.r)}"
|
||||
return "default_0"
|
||||
|
||||
|
||||
def set_adapter_layers(model, enabled=True):
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.disable_adapters = False if enabled else True
|
||||
|
||||
|
||||
def set_weights_and_activate_adapters(model, adapter_names, weights):
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
|
||||
# iterate over each adapter, make it active and set the corresponding scaling weight
|
||||
for adapter_name, weight in zip(adapter_names, weights):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.active_adapter = adapter_name
|
||||
module.scale_layer(weight)
|
||||
|
||||
# set multiple active adapters
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.active_adapter = adapter_names
|
||||
@@ -1,180 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
State dict utilities: utility methods for converting state dicts easily
|
||||
"""
|
||||
import enum
|
||||
|
||||
|
||||
class StateDictType(enum.Enum):
|
||||
"""
|
||||
The mode to use when converting state dicts.
|
||||
"""
|
||||
|
||||
DIFFUSERS_OLD = "diffusers_old"
|
||||
# KOHYA_SS = "kohya_ss" # TODO: implement this
|
||||
PEFT = "peft"
|
||||
DIFFUSERS = "diffusers"
|
||||
|
||||
|
||||
DIFFUSERS_TO_PEFT = {
|
||||
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
|
||||
".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
|
||||
".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
|
||||
".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
|
||||
".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
|
||||
".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
|
||||
".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
|
||||
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
|
||||
}
|
||||
|
||||
DIFFUSERS_OLD_TO_PEFT = {
|
||||
".to_q_lora.up": ".q_proj.lora_B",
|
||||
".to_q_lora.down": ".q_proj.lora_A",
|
||||
".to_k_lora.up": ".k_proj.lora_B",
|
||||
".to_k_lora.down": ".k_proj.lora_A",
|
||||
".to_v_lora.up": ".v_proj.lora_B",
|
||||
".to_v_lora.down": ".v_proj.lora_A",
|
||||
".to_out_lora.up": ".out_proj.lora_B",
|
||||
".to_out_lora.down": ".out_proj.lora_A",
|
||||
}
|
||||
|
||||
PEFT_TO_DIFFUSERS = {
|
||||
".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
|
||||
".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
|
||||
".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
|
||||
".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
|
||||
".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
|
||||
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
|
||||
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
|
||||
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
|
||||
}
|
||||
|
||||
DIFFUSERS_OLD_TO_DIFFUSERS = {
|
||||
".to_q_lora.up": ".q_proj.lora_linear_layer.up",
|
||||
".to_q_lora.down": ".q_proj.lora_linear_layer.down",
|
||||
".to_k_lora.up": ".k_proj.lora_linear_layer.up",
|
||||
".to_k_lora.down": ".k_proj.lora_linear_layer.down",
|
||||
".to_v_lora.up": ".v_proj.lora_linear_layer.up",
|
||||
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
|
||||
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
|
||||
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
|
||||
}
|
||||
|
||||
PEFT_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
|
||||
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
|
||||
}
|
||||
|
||||
DIFFUSERS_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
|
||||
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
|
||||
}
|
||||
|
||||
|
||||
def convert_state_dict(state_dict, mapping):
|
||||
r"""
|
||||
Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
mapping (`dict[str, str]`):
|
||||
The mapping to use for conversion, the mapping should be a dictionary with the following structure:
|
||||
- key: the pattern to replace
|
||||
- value: the pattern to replace with
|
||||
|
||||
Returns:
|
||||
converted_state_dict (`dict`)
|
||||
The converted state dict.
|
||||
"""
|
||||
converted_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if any(pattern in k for pattern in mapping.keys()):
|
||||
for old, new in mapping.items():
|
||||
k = k.replace(old, new)
|
||||
|
||||
converted_state_dict[k] = v
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
|
||||
r"""
|
||||
Converts a state dict to the PEFT format The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or
|
||||
new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
original_type (`StateDictType`, *optional*):
|
||||
The original type of the state dict, if not provided, the method will try to infer it automatically.
|
||||
"""
|
||||
if original_type is None:
|
||||
# Old diffusers to PEFT
|
||||
if any("to_out_lora" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS_OLD
|
||||
elif any("lora_linear_layer" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS
|
||||
else:
|
||||
raise ValueError("Could not automatically infer state dict type")
|
||||
|
||||
if original_type not in PEFT_STATE_DICT_MAPPINGS.keys():
|
||||
raise ValueError(f"Original type {original_type} is not supported")
|
||||
|
||||
mapping = PEFT_STATE_DICT_MAPPINGS[original_type]
|
||||
return convert_state_dict(state_dict, mapping)
|
||||
|
||||
|
||||
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
|
||||
r"""
|
||||
Converts a state dict to new diffusers format. The state dict can be from previous diffusers format
|
||||
(`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will
|
||||
return the state dict as is.
|
||||
|
||||
The method only supports the conversion from diffusers old, PEFT to diffusers new for now.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
original_type (`StateDictType`, *optional*):
|
||||
The original type of the state dict, if not provided, the method will try to infer it automatically.
|
||||
kwargs (`dict`, *args*):
|
||||
Additional arguments to pass to the method.
|
||||
|
||||
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
|
||||
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
|
||||
`get_peft_model_state_dict` method:
|
||||
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
|
||||
but we add it here in case we don't want to rely on that method.
|
||||
"""
|
||||
peft_adapter_name = kwargs.pop("adapter_name", "")
|
||||
peft_adapter_name = "." + peft_adapter_name
|
||||
|
||||
if original_type is None:
|
||||
# Old diffusers to PEFT
|
||||
if any("to_out_lora" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS_OLD
|
||||
elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.PEFT
|
||||
elif any("lora_linear_layer" in k for k in state_dict.keys()):
|
||||
# nothing to do
|
||||
return state_dict
|
||||
else:
|
||||
raise ValueError("Could not automatically infer state dict type")
|
||||
|
||||
if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
|
||||
raise ValueError(f"Original type {original_type} is not supported")
|
||||
|
||||
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
|
||||
return convert_state_dict(state_dict, mapping)
|
||||
@@ -1876,25 +1876,6 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(np.allclose(images, expected, atol=1e-3))
|
||||
|
||||
def test_lycoris(self):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16"
|
||||
).to(torch_device)
|
||||
lora_model_id = "hf-internal-testing/edgLycorisMugler-light"
|
||||
lora_filename = "edgLycorisMugler-light.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
|
||||
images = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
|
||||
images = images[0, -3:, -3:, -1].flatten()
|
||||
expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017])
|
||||
|
||||
self.assertTrue(np.allclose(images, expected, atol=1e-3))
|
||||
|
||||
def test_a1111_with_model_cpu_offload(self):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import (
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor
|
||||
|
||||
|
||||
def create_unet_lora_layers(unet: nn.Module):
|
||||
lora_attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
|
||||
return lora_attn_procs, unet_lora_layers
|
||||
|
||||
|
||||
class LoraLoaderMixinTests(unittest.TestCase):
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
steps_offset=1,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
|
||||
|
||||
pipeline_components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
lora_components = {
|
||||
"unet_lora_layers": unet_lora_layers,
|
||||
"unet_lora_attn_procs": unet_lora_attn_procs,
|
||||
}
|
||||
return pipeline_components, lora_components
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 10
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes)
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
# copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
|
||||
def get_dummy_tokens(self):
|
||||
max_seq_length = 77
|
||||
|
||||
inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
|
||||
|
||||
prepared_inputs = {}
|
||||
prepared_inputs["input_ids"] = inputs
|
||||
return prepared_inputs
|
||||
@@ -359,7 +359,7 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical()
|
||||
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
|
||||
@@ -459,7 +459,7 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-4)
|
||||
self._test_inference_batch_single_identical(test_mean_pixel_difference=False, expected_max_diff=2e-4)
|
||||
|
||||
def test_save_load_local(self):
|
||||
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model
|
||||
|
||||
@@ -1,196 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
|
||||
from transformers.models.clip.configuration_clip import CLIPTextConfig
|
||||
|
||||
from diffusers import AutoencoderKL, BlipDiffusionPipeline, PNDMScheduler, UNet2DConditionModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
|
||||
from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
||||
from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = BlipDiffusionPipeline
|
||||
params = [
|
||||
"prompt",
|
||||
"reference_image",
|
||||
"source_subject_category",
|
||||
"target_subject_category",
|
||||
]
|
||||
batch_params = [
|
||||
"prompt",
|
||||
"reference_image",
|
||||
"source_subject_category",
|
||||
"target_subject_category",
|
||||
]
|
||||
required_optional_params = [
|
||||
"generator",
|
||||
"height",
|
||||
"width",
|
||||
"latents",
|
||||
"guidance_scale",
|
||||
"num_inference_steps",
|
||||
"neg_prompt",
|
||||
"guidance_scale",
|
||||
"prompt_strength",
|
||||
"prompt_reps",
|
||||
]
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
vocab_size=1000,
|
||||
hidden_size=16,
|
||||
intermediate_size=16,
|
||||
projection_dim=16,
|
||||
num_hidden_layers=1,
|
||||
num_attention_heads=1,
|
||||
max_position_embeddings=77,
|
||||
)
|
||||
text_encoder = ContextCLIPTextModel(text_encoder_config)
|
||||
|
||||
vae = AutoencoderKL(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(32,),
|
||||
layers_per_block=1,
|
||||
act_fn="silu",
|
||||
latent_channels=4,
|
||||
norm_num_groups=16,
|
||||
sample_size=16,
|
||||
)
|
||||
|
||||
blip_vision_config = {
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 1,
|
||||
"image_size": 224,
|
||||
"patch_size": 14,
|
||||
"hidden_act": "quick_gelu",
|
||||
}
|
||||
|
||||
blip_qformer_config = {
|
||||
"vocab_size": 1000,
|
||||
"hidden_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 1,
|
||||
"intermediate_size": 16,
|
||||
"max_position_embeddings": 512,
|
||||
"cross_attention_frequency": 1,
|
||||
"encoder_hidden_size": 16,
|
||||
}
|
||||
qformer_config = Blip2Config(
|
||||
vision_config=blip_vision_config,
|
||||
qformer_config=blip_qformer_config,
|
||||
num_query_tokens=16,
|
||||
tokenizer="hf-internal-testing/tiny-random-bert",
|
||||
)
|
||||
qformer = Blip2QFormerModel(qformer_config)
|
||||
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(16, 32),
|
||||
norm_num_groups=16,
|
||||
layers_per_block=1,
|
||||
sample_size=16,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=16,
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
scheduler = PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
set_alpha_to_one=False,
|
||||
skip_prk_steps=True,
|
||||
)
|
||||
|
||||
vae.eval()
|
||||
qformer.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
image_processor = BlipImageProcessor()
|
||||
|
||||
components = {
|
||||
"text_encoder": text_encoder,
|
||||
"vae": vae,
|
||||
"qformer": qformer,
|
||||
"unet": unet,
|
||||
"tokenizer": tokenizer,
|
||||
"scheduler": scheduler,
|
||||
"image_processor": image_processor,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
np.random.seed(seed)
|
||||
reference_image = np.random.rand(32, 32, 3) * 255
|
||||
reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "swimming underwater",
|
||||
"generator": generator,
|
||||
"reference_image": reference_image,
|
||||
"source_subject_category": "dog",
|
||||
"target_subject_category": "dog",
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"guidance_scale": 7.5,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_blipdiffusion(self):
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image = pipe(**self.get_dummy_inputs(device))[0]
|
||||
image_slice = image[0, -3:, -3:, 0]
|
||||
|
||||
assert image.shape == (1, 16, 16, 4)
|
||||
|
||||
expected_slice = np.array([0.7096, 0.5900, 0.6703, 0.4032, 0.7766, 0.3629, 0.5447, 0.4149, 0.8172])
|
||||
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
|
||||
@@ -1,216 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
|
||||
from transformers.models.clip.configuration_clip import CLIPTextConfig
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
BlipDiffusionControlNetPipeline,
|
||||
ControlNetModel,
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
|
||||
from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
||||
from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = BlipDiffusionControlNetPipeline
|
||||
params = [
|
||||
"prompt",
|
||||
"reference_image",
|
||||
"source_subject_category",
|
||||
"target_subject_category",
|
||||
"condtioning_image",
|
||||
]
|
||||
batch_params = [
|
||||
"prompt",
|
||||
"reference_image",
|
||||
"source_subject_category",
|
||||
"target_subject_category",
|
||||
"condtioning_image",
|
||||
]
|
||||
required_optional_params = [
|
||||
"generator",
|
||||
"height",
|
||||
"width",
|
||||
"latents",
|
||||
"guidance_scale",
|
||||
"num_inference_steps",
|
||||
"neg_prompt",
|
||||
"guidance_scale",
|
||||
"prompt_strength",
|
||||
"prompt_reps",
|
||||
]
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
vocab_size=1000,
|
||||
hidden_size=16,
|
||||
intermediate_size=16,
|
||||
projection_dim=16,
|
||||
num_hidden_layers=1,
|
||||
num_attention_heads=1,
|
||||
max_position_embeddings=77,
|
||||
)
|
||||
text_encoder = ContextCLIPTextModel(text_encoder_config)
|
||||
|
||||
vae = AutoencoderKL(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(32,),
|
||||
layers_per_block=1,
|
||||
act_fn="silu",
|
||||
latent_channels=4,
|
||||
norm_num_groups=16,
|
||||
sample_size=16,
|
||||
)
|
||||
|
||||
blip_vision_config = {
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 1,
|
||||
"image_size": 224,
|
||||
"patch_size": 14,
|
||||
"hidden_act": "quick_gelu",
|
||||
}
|
||||
|
||||
blip_qformer_config = {
|
||||
"vocab_size": 1000,
|
||||
"hidden_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 1,
|
||||
"intermediate_size": 16,
|
||||
"max_position_embeddings": 512,
|
||||
"cross_attention_frequency": 1,
|
||||
"encoder_hidden_size": 16,
|
||||
}
|
||||
qformer_config = Blip2Config(
|
||||
vision_config=blip_vision_config,
|
||||
qformer_config=blip_qformer_config,
|
||||
num_query_tokens=16,
|
||||
tokenizer="hf-internal-testing/tiny-random-bert",
|
||||
)
|
||||
qformer = Blip2QFormerModel(qformer_config)
|
||||
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(4, 16),
|
||||
layers_per_block=1,
|
||||
norm_num_groups=4,
|
||||
sample_size=16,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=16,
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
scheduler = PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
set_alpha_to_one=False,
|
||||
skip_prk_steps=True,
|
||||
)
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=(4, 16),
|
||||
layers_per_block=1,
|
||||
in_channels=4,
|
||||
norm_num_groups=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=16,
|
||||
conditioning_embedding_out_channels=(8, 16),
|
||||
)
|
||||
|
||||
vae.eval()
|
||||
qformer.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
image_processor = BlipImageProcessor()
|
||||
|
||||
components = {
|
||||
"text_encoder": text_encoder,
|
||||
"vae": vae,
|
||||
"qformer": qformer,
|
||||
"unet": unet,
|
||||
"tokenizer": tokenizer,
|
||||
"scheduler": scheduler,
|
||||
"controlnet": controlnet,
|
||||
"image_processor": image_processor,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
np.random.seed(seed)
|
||||
reference_image = np.random.rand(32, 32, 3) * 255
|
||||
reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
|
||||
cond_image = np.random.rand(32, 32, 3) * 255
|
||||
cond_image = Image.fromarray(cond_image.astype("uint8")).convert("RGBA")
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "swimming underwater",
|
||||
"generator": generator,
|
||||
"reference_image": reference_image,
|
||||
"condtioning_image": cond_image,
|
||||
"source_subject_category": "dog",
|
||||
"target_subject_category": "dog",
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"guidance_scale": 7.5,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_blipdiffusion_controlnet(self):
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image = pipe(**self.get_dummy_inputs(device))[0]
|
||||
image_slice = image[0, -3:, -3:, 0]
|
||||
|
||||
assert image.shape == (1, 16, 16, 4)
|
||||
expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
|
||||
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
@@ -96,7 +96,7 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self.assertLessEqual(max_diff, 1e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
|
||||
self._test_inference_batch_single_identical(relax_max_difference=True, expected_max_diff=1e-3)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
|
||||
@@ -224,7 +224,15 @@ class KandinskyPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
@skip_mps
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=1e-2)
|
||||
test_max_difference = torch_device == "cpu"
|
||||
relax_max_difference = True
|
||||
test_mean_pixel_difference = False
|
||||
|
||||
self._test_inference_batch_single_identical(
|
||||
test_max_difference=test_max_difference,
|
||||
relax_max_difference=relax_max_difference,
|
||||
test_mean_pixel_difference=test_mean_pixel_difference,
|
||||
)
|
||||
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
|
||||
@@ -224,7 +224,15 @@ class KandinskyV22PriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
|
||||
@skip_mps
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
|
||||
test_max_difference = torch_device == "cpu"
|
||||
relax_max_difference = True
|
||||
test_mean_pixel_difference = False
|
||||
|
||||
self._test_inference_batch_single_identical(
|
||||
test_max_difference=test_max_difference,
|
||||
relax_max_difference=relax_max_difference,
|
||||
test_mean_pixel_difference=test_mean_pixel_difference,
|
||||
)
|
||||
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
|
||||
@@ -234,7 +234,15 @@ class KandinskyV22PriorEmb2EmbPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
|
||||
@skip_mps
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=1e-2)
|
||||
test_max_difference = torch_device == "cpu"
|
||||
relax_max_difference = True
|
||||
test_mean_pixel_difference = False
|
||||
|
||||
self._test_inference_batch_single_identical(
|
||||
test_max_difference=test_max_difference,
|
||||
relax_max_difference=relax_max_difference,
|
||||
test_mean_pixel_difference=test_mean_pixel_difference,
|
||||
)
|
||||
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
|
||||
@@ -373,7 +373,7 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical()
|
||||
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
|
||||
@@ -44,11 +44,11 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
@property
|
||||
def text_embedder_hidden_size(self):
|
||||
return 16
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_input_dim(self):
|
||||
return 16
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_embed_dim(self):
|
||||
@@ -201,7 +201,14 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
self._test_inference_batch_consistent(batch_sizes=[1, 2])
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=6e-3)
|
||||
test_max_difference = torch_device == "cpu"
|
||||
relax_max_difference = True
|
||||
|
||||
self._test_inference_batch_single_identical(
|
||||
batch_size=2,
|
||||
test_max_difference=test_max_difference,
|
||||
relax_max_difference=relax_max_difference,
|
||||
)
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -52,11 +52,11 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
@property
|
||||
def text_embedder_hidden_size(self):
|
||||
return 16
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_input_dim(self):
|
||||
return 16
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_embed_dim(self):
|
||||
@@ -71,10 +71,10 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPVisionConfig(
|
||||
hidden_size=self.text_embedder_hidden_size,
|
||||
image_size=32,
|
||||
image_size=64,
|
||||
projection_dim=self.text_embedder_hidden_size,
|
||||
intermediate_size=24,
|
||||
num_attention_heads=2,
|
||||
intermediate_size=37,
|
||||
num_attention_heads=4,
|
||||
num_channels=3,
|
||||
num_hidden_layers=5,
|
||||
patch_size=1,
|
||||
@@ -170,7 +170,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
input_image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
@@ -219,12 +219,15 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
def test_inference_batch_consistent(self):
|
||||
# NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
|
||||
self._test_inference_batch_consistent(batch_sizes=[2])
|
||||
self._test_inference_batch_consistent(batch_sizes=[1, 2])
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
test_max_difference = torch_device == "cpu"
|
||||
relax_max_difference = True
|
||||
self._test_inference_batch_single_identical(
|
||||
batch_size=2,
|
||||
expected_max_diff=5e-3,
|
||||
test_max_difference=test_max_difference,
|
||||
relax_max_difference=relax_max_difference,
|
||||
)
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
|
||||
@@ -499,28 +499,6 @@ class StableDiffusionPipelineFastTests(
|
||||
negative_prompt = None
|
||||
num_images_per_prompt = 1
|
||||
logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
prompt = 100 * "@"
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
negative_text_embeddings, text_embeddings = sd_pipe.encode_prompt(
|
||||
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
if negative_text_embeddings is not None:
|
||||
text_embeddings = torch.cat([negative_text_embeddings, text_embeddings])
|
||||
|
||||
# 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
|
||||
assert cap_logger.out.count("@") == 25
|
||||
|
||||
negative_prompt = "Hello"
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
negative_text_embeddings_2, text_embeddings_2 = sd_pipe.encode_prompt(
|
||||
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
if negative_text_embeddings_2 is not None:
|
||||
text_embeddings_2 = torch.cat([negative_text_embeddings_2, text_embeddings_2])
|
||||
|
||||
assert cap_logger.out == cap_logger_2.out
|
||||
|
||||
prompt = 25 * "@"
|
||||
with CaptureLogger(logger) as cap_logger_3:
|
||||
@@ -530,8 +508,28 @@ class StableDiffusionPipelineFastTests(
|
||||
if negative_text_embeddings_3 is not None:
|
||||
text_embeddings_3 = torch.cat([negative_text_embeddings_3, text_embeddings_3])
|
||||
|
||||
prompt = 100 * "@"
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
negative_text_embeddings, text_embeddings = sd_pipe.encode_prompt(
|
||||
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
if negative_text_embeddings is not None:
|
||||
text_embeddings = torch.cat([negative_text_embeddings, text_embeddings])
|
||||
|
||||
negative_prompt = "Hello"
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
negative_text_embeddings_2, text_embeddings_2 = sd_pipe.encode_prompt(
|
||||
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
if negative_text_embeddings_2 is not None:
|
||||
text_embeddings_2 = torch.cat([negative_text_embeddings_2, text_embeddings_2])
|
||||
|
||||
assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
|
||||
assert text_embeddings.shape[1] == 77
|
||||
|
||||
assert cap_logger.out == cap_logger_2.out
|
||||
# 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
|
||||
assert cap_logger.out.count("@") == 25
|
||||
assert cap_logger_3.out == ""
|
||||
|
||||
def test_stable_diffusion_height_width_opt(self):
|
||||
|
||||
@@ -250,7 +250,6 @@ class StableDiffusion2PipelineFastTests(
|
||||
negative_prompt = None
|
||||
num_images_per_prompt = 1
|
||||
logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
prompt = 25 * "@"
|
||||
with CaptureLogger(logger) as cap_logger_3:
|
||||
|
||||
@@ -20,20 +20,17 @@ import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
EulerDiscreteScheduler,
|
||||
MultiAdapter,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
T2IAdapter,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor
|
||||
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -44,7 +41,7 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
|
||||
def get_dummy_components(self, adapter_type="full_adapter_xl"):
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
@@ -100,38 +97,13 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
|
||||
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
if adapter_type == "full_adapter_xl":
|
||||
adapter = T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=4,
|
||||
adapter_type=adapter_type,
|
||||
)
|
||||
elif adapter_type == "multi_adapter":
|
||||
adapter = MultiAdapter(
|
||||
[
|
||||
T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=4,
|
||||
adapter_type="full_adapter_xl",
|
||||
),
|
||||
T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=4,
|
||||
adapter_type="full_adapter_xl",
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter_xl', or 'multi_adapter''"
|
||||
)
|
||||
|
||||
adapter = T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=4,
|
||||
adapter_type="full_adapter_xl",
|
||||
)
|
||||
components = {
|
||||
"adapter": adapter,
|
||||
"unet": unet,
|
||||
@@ -146,12 +118,8 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0, num_images=1):
|
||||
if num_images == 1:
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
else:
|
||||
image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
@@ -182,202 +150,3 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
[0.5752919, 0.6022097, 0.4728038, 0.49861962, 0.57084894, 0.4644975, 0.5193715, 0.5133664, 0.4729858]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
|
||||
|
||||
|
||||
class StableDiffusionXLMultiAdapterPipelineFastTests(
|
||||
StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
def get_dummy_components(self):
|
||||
return super().get_dummy_components("multi_adapter")
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed, num_images=2)
|
||||
inputs["adapter_conditioning_scale"] = [0.5, 0.5]
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_adapter_default_case(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionXLAdapterPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[0.5813032, 0.60995954, 0.47563356, 0.5056669, 0.57199144, 0.4631841, 0.5176794, 0.51252556, 0.47183886]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
|
||||
|
||||
def test_inference_batch_consistent(
|
||||
self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
|
||||
):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
for batch_size in batch_sizes:
|
||||
batched_inputs = {}
|
||||
for name, value in inputs.items():
|
||||
if name in self.batch_params:
|
||||
# prompt is string
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
elif name == "image":
|
||||
batched_images = []
|
||||
|
||||
for image in value:
|
||||
batched_images.append(batch_size * [image])
|
||||
|
||||
batched_inputs[name] = batched_images
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
elif name == "batch_size":
|
||||
batched_inputs[name] = batch_size
|
||||
else:
|
||||
batched_inputs[name] = value
|
||||
|
||||
for arg in additional_params_copy_to_batched_inputs:
|
||||
batched_inputs[arg] = inputs[arg]
|
||||
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
output = pipe(**batched_inputs)
|
||||
|
||||
assert len(output[0]) == batch_size
|
||||
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
output = pipe(**batched_inputs)[0]
|
||||
|
||||
assert output.shape[0] == batch_size
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
if key == "image":
|
||||
batched_images = []
|
||||
|
||||
for image in inputs[key]:
|
||||
batched_images.append(batch_size * [image])
|
||||
|
||||
inputs[key] = batched_images
|
||||
else:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
def test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=3,
|
||||
test_max_difference=None,
|
||||
test_mean_pixel_difference=None,
|
||||
relax_max_difference=False,
|
||||
expected_max_diff=2e-3,
|
||||
additional_params_copy_to_batched_inputs=["num_inference_steps"],
|
||||
):
|
||||
if test_max_difference is None:
|
||||
# TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
|
||||
# make sure that batched and non-batched is identical
|
||||
test_max_difference = torch_device != "mps"
|
||||
|
||||
if test_mean_pixel_difference is None:
|
||||
# TODO same as above
|
||||
test_mean_pixel_difference = torch_device != "mps"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batch_size = batch_size
|
||||
for name, value in inputs.items():
|
||||
if name in self.batch_params:
|
||||
# prompt is string
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
elif name == "image":
|
||||
batched_images = []
|
||||
|
||||
for image in value:
|
||||
batched_images.append(batch_size * [image])
|
||||
|
||||
batched_inputs[name] = batched_images
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
elif name == "batch_size":
|
||||
batched_inputs[name] = batch_size
|
||||
elif name == "generator":
|
||||
batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)]
|
||||
else:
|
||||
batched_inputs[name] = value
|
||||
|
||||
for arg in additional_params_copy_to_batched_inputs:
|
||||
batched_inputs[arg] = inputs[arg]
|
||||
|
||||
output_batch = pipe(**batched_inputs)
|
||||
assert output_batch[0].shape[0] == batch_size
|
||||
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
output = pipe(**inputs)
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
if test_max_difference:
|
||||
if relax_max_difference:
|
||||
# Taking the median of the largest <n> differences
|
||||
# is resilient to outliers
|
||||
diff = np.abs(output_batch[0][0] - output[0][0])
|
||||
diff = diff.flatten()
|
||||
diff.sort()
|
||||
max_diff = np.median(diff[-5:])
|
||||
else:
|
||||
max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
|
||||
assert max_diff < expected_max_diff
|
||||
|
||||
if test_mean_pixel_difference:
|
||||
assert_mean_pixel_difference(output_batch[0][0], output[0][0])
|
||||
|
||||
@@ -182,7 +182,9 @@ class StableUnCLIPPipelineFastTests(
|
||||
# Overriding PipelineTesterMixin::test_inference_batch_single_identical
|
||||
# because UnCLIP undeterminism requires a looser check.
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
|
||||
test_max_difference = torch_device in ["cpu", "mps"]
|
||||
|
||||
self._test_inference_batch_single_identical(test_max_difference=test_max_difference)
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -196,7 +196,9 @@ class StableUnCLIPImg2ImgPipelineFastTests(
|
||||
# Overriding PipelineTesterMixin::test_inference_batch_single_identical
|
||||
# because undeterminism requires a looser check.
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
|
||||
test_max_difference = torch_device in ["cpu", "mps"]
|
||||
|
||||
self._test_inference_batch_single_identical(test_max_difference=test_max_difference)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
|
||||
@@ -374,11 +374,11 @@ class PipelineTesterMixin:
|
||||
f"Required optional parameters not present: {remaining_required_optional_parameters}",
|
||||
)
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2]):
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2, 4, 13]):
|
||||
self._test_inference_batch_consistent(batch_sizes=batch_sizes)
|
||||
|
||||
def _test_inference_batch_consistent(
|
||||
self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"]
|
||||
self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
|
||||
):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -386,103 +386,137 @@ class PipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# prepare batched inputs
|
||||
batched_inputs = []
|
||||
# batchify inputs
|
||||
for batch_size in batch_sizes:
|
||||
batched_input = {}
|
||||
batched_input.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_input[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_input[name][-1] = 100 * "very long"
|
||||
batched_inputs = {}
|
||||
for name, value in inputs.items():
|
||||
if name in self.batch_params:
|
||||
# prompt is string
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
# or else we have images
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
elif name == "batch_size":
|
||||
batched_inputs[name] = batch_size
|
||||
else:
|
||||
batched_input[name] = batch_size * [value]
|
||||
batched_inputs[name] = value
|
||||
|
||||
if "generator" in inputs:
|
||||
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
for arg in additional_params_copy_to_batched_inputs:
|
||||
batched_inputs[arg] = inputs[arg]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_input["batch_size"] = batch_size
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
batched_inputs.append(batched_input)
|
||||
if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
|
||||
batched_inputs.pop("output_type")
|
||||
|
||||
output = pipe(**batched_inputs)
|
||||
|
||||
assert len(output[0]) == batch_size
|
||||
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
|
||||
batched_inputs.pop("output_type")
|
||||
|
||||
output = pipe(**batched_inputs)[0]
|
||||
|
||||
assert output.shape[0] == batch_size
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
|
||||
output = pipe(**batched_input)
|
||||
assert len(output[0]) == batch_size
|
||||
|
||||
def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4):
|
||||
self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)
|
||||
|
||||
def _test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=2,
|
||||
batch_size=3,
|
||||
test_max_difference=None,
|
||||
test_mean_pixel_difference=None,
|
||||
relax_max_difference=False,
|
||||
expected_max_diff=1e-4,
|
||||
additional_params_copy_to_batched_inputs=["num_inference_steps"],
|
||||
):
|
||||
if test_max_difference is None:
|
||||
# TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
|
||||
# make sure that batched and non-batched is identical
|
||||
test_max_difference = torch_device != "mps"
|
||||
|
||||
if test_mean_pixel_difference is None:
|
||||
# TODO same as above
|
||||
test_mean_pixel_difference = torch_device != "mps"
|
||||
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for components in pipe.components.values():
|
||||
if hasattr(components, "set_default_attn_processor"):
|
||||
components.set_default_attn_processor()
|
||||
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batched_inputs.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
batch_size = batch_size
|
||||
for name, value in inputs.items():
|
||||
if name in self.batch_params:
|
||||
# prompt is string
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
# or else we have images
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
elif name == "batch_size":
|
||||
batched_inputs[name] = batch_size
|
||||
elif name == "generator":
|
||||
batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)]
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
if "generator" in inputs:
|
||||
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_inputs["batch_size"] = batch_size
|
||||
batched_inputs[name] = value
|
||||
|
||||
for arg in additional_params_copy_to_batched_inputs:
|
||||
batched_inputs[arg] = inputs[arg]
|
||||
|
||||
output = pipe(**inputs)
|
||||
output_batch = pipe(**batched_inputs)
|
||||
if self.pipeline_class.__name__ != "DanceDiffusionPipeline":
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
output_batch = pipe(**batched_inputs)
|
||||
assert output_batch[0].shape[0] == batch_size
|
||||
|
||||
max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
|
||||
assert max_diff < expected_max_diff
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
output = pipe(**inputs)
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
if test_max_difference:
|
||||
if relax_max_difference:
|
||||
# Taking the median of the largest <n> differences
|
||||
# is resilient to outliers
|
||||
diff = np.abs(output_batch[0][0] - output[0][0])
|
||||
diff = diff.flatten()
|
||||
diff.sort()
|
||||
max_diff = np.median(diff[-5:])
|
||||
else:
|
||||
max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
|
||||
assert max_diff < expected_max_diff
|
||||
|
||||
if test_mean_pixel_difference:
|
||||
assert_mean_pixel_difference(output_batch[0][0], output[0][0])
|
||||
|
||||
def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
|
||||
components = self.get_dummy_components()
|
||||
@@ -494,9 +528,8 @@ class PipelineTesterMixin:
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
output = pipe(**self.get_dummy_inputs(generator_device))[0]
|
||||
output_tuple = pipe(**self.get_dummy_inputs(generator_device), return_dict=False)[0]
|
||||
output = pipe(**self.get_dummy_inputs(torch_device))[0]
|
||||
output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
@@ -677,12 +710,11 @@ class PipelineTesterMixin:
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_with_slicing = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
|
||||
@@ -62,14 +62,14 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet3DConditionModel(
|
||||
block_out_channels=(32, 32),
|
||||
block_out_channels=(32, 64, 64, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
|
||||
up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
|
||||
cross_attention_dim=4,
|
||||
down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"),
|
||||
up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
||||
cross_attention_dim=32,
|
||||
attention_head_dim=4,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
@@ -81,27 +81,27 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=(32,),
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D"],
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=32,
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=4,
|
||||
intermediate_size=16,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
projection_dim=512,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
@@ -141,8 +141,8 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
frames = sd_pipe(**inputs).frames
|
||||
image_slice = frames[0][-3:, -3:, -1]
|
||||
|
||||
assert frames[0].shape == (32, 32, 3)
|
||||
expected_slice = np.array([91.0, 152.0, 66.0, 192.0, 94.0, 126.0, 101.0, 123.0, 152.0])
|
||||
assert frames[0].shape == (64, 64, 3)
|
||||
expected_slice = np.array([158.0, 160.0, 153.0, 125.0, 100.0, 121.0, 111.0, 93.0, 113.0])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user