Compare commits

4 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Dhruv Nair | 4b682be902 | update | 2024-03-12 13:08:23 +00:00 |
| Dhruv Nair | 3d0bb51d53 | update | 2024-03-12 13:04:25 +00:00 |
| Dhruv Nair | 4b72aae0cd | update | 2024-03-12 08:52:55 +00:00 |
| Dhruv Nair | 33bbe58ea7 | update | 2024-03-12 05:41:37 +00:00 |
85 changed files with 674 additions and 1515 deletions
+13 -154
@@ -12,13 +12,13 @@ specific language governing permissions and limitations under the License.
# Stable Cascade
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
Diffusion 1.5.
Therefore, this kind of model is well suited for usages where efficiency is important. Furthermore, all known extensions
@@ -30,154 +30,13 @@ The original codebase can be found at [Stability-AI/StableCascade](https://githu
Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade to generate images,
hence the name "Stable Cascade".
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
for generating the small 24 x 24 latents given a text prompt.
The Stage C model operates on the small 24 x 24 latents and denoises the latents conditioned on text prompts. The model is also the largest component in the Cascade pipeline and is meant to be used with the `StableCascadePriorPipeline`
The Stage B and Stage A models are used with the `StableCascadeDecoderPipeline` and are responsible for generating the final image given the small 24 x 24 latents.
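The compression figures quoted above can be checked with a few lines of arithmetic. This is a minimal sketch, not code from the library: `latent_side` is a hypothetical helper, and rounding down to a whole latent size is an assumption chosen to match the 24 x 24 figure in the text.

```python
def latent_side(image_side: int, compression_factor: float) -> int:
    """Side length of the latent grid for a square image (hypothetical helper).

    Assumption: the fractional result is rounded down, which reproduces the
    128 x 128 and 24 x 24 figures quoted in the documentation text.
    """
    return int(image_side / compression_factor)


sd_side = latent_side(1024, 8)        # Stable Diffusion: 1024 / 8 = 128
cascade_side = latent_side(1024, 42)  # Stable Cascade: 1024 / 42 is about 24.4, so 24

# Ratio of spatial positions the text-conditional model has to process.
position_ratio = (sd_side ** 2) / (cascade_side ** 2)

print(sd_side, cascade_side, round(position_ratio, 1))
```

With these numbers the Stage C latent grid has roughly 28 times fewer spatial positions than a Stable Diffusion latent for the same 1024 x 1024 image, which is where the cheaper training and inference come from.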
<Tip warning={true}>
There are some restrictions on data types that can be used with the Stable Cascade models. The official checkpoints for the `StableCascadePriorPipeline` do not support the `torch.float16` data type. Please use `torch.bfloat16` instead.
In order to use the `torch.bfloat16` data type with the `StableCascadeDecoderPipeline` you need to have PyTorch 2.2.0 or higher installed. This also means that using the `StableCascadeCombinedPipeline` with `torch.bfloat16` requires PyTorch 2.2.0 or higher, since it calls the `StableCascadeDecoderPipeline` internally.
If it is not possible to install PyTorch 2.2.0 or higher in your environment, the `StableCascadeDecoderPipeline` can be used on its own with the `torch.float16` data type. You can download the full precision or `bf16` variant weights for the pipeline and cast the weights to `torch.float16`.
</Tip>
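The data type rules in the tip above can be summarised as a small decision helper. This is only a sketch under stated assumptions: `pick_cascade_dtypes` and `release_tuple` are hypothetical names, the helper works on a version string such as the value of `torch.__version__`, and it returns dtype names as strings so that it runs without PyTorch installed.

```python
import re


def release_tuple(version: str) -> tuple:
    """Turn a version string such as '2.1.2+cu118' into (2, 1, 2)."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def pick_cascade_dtypes(torch_version: str) -> dict:
    """Suggest dtype names per pipeline, following the tip above (hypothetical helper).

    - The prior checkpoints do not support float16, so bfloat16 is always used.
    - The decoder can use bfloat16 only with PyTorch 2.2.0 or higher;
      otherwise it falls back to float16.
    """
    decoder = "bfloat16" if release_tuple(torch_version) >= (2, 2, 0) else "float16"
    return {"prior": "bfloat16", "decoder": decoder}


print(pick_cascade_dtypes("2.1.2+cu118"))
print(pick_cascade_dtypes("2.2.0"))
```

In a real script the returned names would map to `torch.bfloat16` and `torch.float16` when passing `torch_dtype` to `from_pretrained`.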
## Usage example
```python
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
prior.enable_model_cpu_offload()
prior_output = prior(
prompt=prompt,
height=1024,
width=1024,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=1,
num_inference_steps=20
)
decoder.enable_model_cpu_offload()
decoder_output = decoder(
image_embeddings=prior_output.image_embeddings.to(torch.float16),
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=0.0,
output_type="pil",
num_inference_steps=10
).images[0]
decoder_output.save("cascade.png")
```
## Using the Lite Versions of the Stage B and Stage C models
```python
import torch
from diffusers import (
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
StableCascadeUNet,
)
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""
prior_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade-prior", subfolder="prior_lite")
decoder_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade", subfolder="decoder_lite")
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet)
prior.enable_model_cpu_offload()
prior_output = prior(
prompt=prompt,
height=1024,
width=1024,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=1,
num_inference_steps=20
)
decoder.enable_model_cpu_offload()
decoder_output = decoder(
image_embeddings=prior_output.image_embeddings,
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=0.0,
output_type="pil",
num_inference_steps=10
).images[0]
decoder_output.save("cascade.png")
```
## Loading original checkpoints with `from_single_file`
Loading the original format checkpoints is supported via `from_single_file` method in the StableCascadeUNet.
```python
import torch
from diffusers import (
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
StableCascadeUNet,
)
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""
prior_unet = StableCascadeUNet.from_single_file(
"https://huggingface.co/stabilityai/stable-cascade/resolve/main/stage_c_bf16.safetensors",
torch_dtype=torch.bfloat16
)
decoder_unet = StableCascadeUNet.from_single_file(
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors",
torch_dtype=torch.bfloat16
)
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet, torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet, torch_dtype=torch.bfloat16)
prior.enable_model_cpu_offload()
prior_output = prior(
prompt=prompt,
height=1024,
width=1024,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=1,
num_inference_steps=20
)
decoder.enable_model_cpu_offload()
decoder_output = decoder(
image_embeddings=prior_output.image_embeddings,
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=0.0,
output_type="pil",
num_inference_steps=10
).images[0]
decoder_output.save("cascade-single-file.png")
```
## Uses
### Direct Use
@@ -194,7 +53,7 @@ Excluded uses are described below.
### Out-of-Scope Use
The model was not trained to be factual or true representations of people or events,
The model was not trained to be factual or true representations of people or events,
and therefore using the model to generate such content is out-of-scope for the abilities of this model.
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
@@ -1215,7 +1215,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
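Note on the `logger.warning` / `logger.warn` pairs in this and the following hunks: in the standard library, `Logger.warn` is a deprecated alias of `Logger.warning`. The sketch below uses a plain `logging` logger to illustrate this; the loggers in these scripts come from other helpers, so treating them the same way is an assumption. `uses_deprecated_alias` is a hypothetical helper.

```python
import logging
import warnings

# Plain standard-library logger used only for this illustration.
logger = logging.getLogger("warn_vs_warning_demo")
logger.addHandler(logging.NullHandler())
logger.propagate = False  # keep the probe message out of the console


def uses_deprecated_alias(log) -> bool:
    """Return True if `log.warn` is missing or emits a DeprecationWarning."""
    if not hasattr(log, "warn"):
        return True  # alias not available on this interpreter
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        log.warn("probe message")
    return any(issubclass(w.category, DeprecationWarning) for w in caught)


print(uses_deprecated_alias(logger))

# The non-deprecated spelling works on every supported Python version.
logger.warning("xFormers 0.0.16 cannot be used for training in some GPUs.")
```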
@@ -1366,14 +1366,14 @@ def main(args):
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warning(
logger.warn(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warning(
logger.warn(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
@@ -1407,11 +1407,11 @@ def main(args):
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warning(
logger.warn(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warning(
logger.warn(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
@@ -1317,7 +1317,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
@@ -1522,14 +1522,14 @@ def main(args):
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warning(
logger.warn(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warning(
logger.warn(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
@@ -1563,11 +1563,11 @@ def main(args):
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warning(
logger.warn(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warning(
logger.warn(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
@@ -452,7 +452,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
self.enable_xformers_memory_efficient_attention()
@@ -308,7 +308,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
tracker.log({"validation": formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1068,7 +1068,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -180,7 +180,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
logger_name = "test" if is_final_validation else "validation"
tracker.log({logger_name: formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -928,7 +928,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -325,7 +325,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
tracker.log({"validation": formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1083,7 +1083,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -285,7 +285,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
tracker.log({f"validation/{name}": formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1023,7 +1023,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -303,7 +303,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
tracker.log({f"validation/{name}": formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1083,7 +1083,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
+2 -2
@@ -178,7 +178,7 @@ def log_validation(
tracker.log({tracker_key: formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -861,7 +861,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
+1 -1
@@ -128,7 +128,7 @@ def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args
wandb.log({"validation": formatted_images})
else:
logger.warning(f"image logging not implemented for {args.report_to}")
logger.warn(f"image logging not implemented for {args.report_to}")
return image_logs
+2 -2
@@ -178,7 +178,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
tracker.log({tracker_key: formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -929,7 +929,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -904,7 +904,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
attention_class = CustomDiffusionXFormersAttnProcessor
+1 -1
@@ -987,7 +987,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
+1 -1
@@ -895,7 +895,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -1141,7 +1141,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
@@ -1317,14 +1317,14 @@ def main(args):
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warning(
logger.warn(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warning(
logger.warn(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
@@ -1358,11 +1358,11 @@ def main(args):
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warning(
logger.warn(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warning(
logger.warn(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
@@ -488,7 +488,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -580,7 +580,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -177,7 +177,7 @@ def log_validation(vae, image_encoder, image_processor, unet, args, accelerator,
}
)
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -534,7 +534,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -180,7 +180,7 @@ def log_validation(
}
)
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -219,7 +219,7 @@ def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name=
if args.num_classes is not None:
class_labels = list(range(args.num_classes))
else:
logger.warning(
logger.warn(
"The model is class-conditional but the number of classes is not set. The generated images will be"
" unconditional rather than class-conditional."
)
@@ -266,7 +266,7 @@ def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name=
tracker.log({f"validation/{name}": formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -863,14 +863,14 @@ def main(args):
elif args.model_config_name_or_path is None:
# TODO: use default architectures from iCT paper
if not args.class_conditional and (args.num_classes is not None or args.class_embed_type is not None):
logger.warning(
logger.warn(
f"`--class_conditional` is set to `False` but `--num_classes` is set to {args.num_classes} and"
f" `--class_embed_type` is set to {args.class_embed_type}. These values will be overridden to `None`."
)
args.num_classes = None
args.class_embed_type = None
elif args.class_conditional and args.num_classes is None and args.class_embed_type is None:
logger.warning(
logger.warn(
"`--class_conditional` is set to `True` but neither `--num_classes` nor `--class_embed_type` is set."
"`class_conditional` will be overridden to `False`."
)
@@ -996,7 +996,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -407,7 +407,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
tracker.log({"validation": formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -1057,7 +1057,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -574,7 +574,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -672,7 +672,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -516,7 +516,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -608,7 +608,7 @@ def main():
# Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process:
if args.push_to_hub and args.only_save_embeds:
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = not args.only_save_embeds
@@ -541,7 +541,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -645,7 +645,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -901,7 +901,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if args.push_to_hub and args.only_save_embeds:
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = not args.only_save_embeds
@@ -108,7 +108,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
}
)
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -523,7 +523,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -687,7 +687,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -916,7 +916,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if args.push_to_hub and not args.save_as_full_pipeline:
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = args.save_as_full_pipeline
@@ -410,7 +410,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
model.enable_xformers_memory_efficient_attention()
@@ -629,7 +629,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -167,7 +167,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
tracker.log({"validation": formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
@@ -932,7 +932,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -183,7 +183,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
}
)
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -608,7 +608,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -497,7 +497,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -616,7 +616,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -712,7 +712,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -708,7 +708,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -966,7 +966,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if args.push_to_hub and not args.save_as_full_pipeline:
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = args.save_as_full_pipeline
@@ -711,7 +711,7 @@ def main():
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
@@ -1022,7 +1022,7 @@ def main():
)
if args.push_to_hub and not args.save_as_full_pipeline:
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = args.save_as_full_pipeline
@@ -408,7 +408,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
model.enable_xformers_memory_efficient_attention()
@@ -184,7 +184,7 @@ def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dty
}
)
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
@@ -182,7 +182,7 @@ def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dty
}
)
else:
logger.warning(f"image logging not implemented for {tracker.name}")
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
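Note on the `logger.warning` / `logger.warn` pairs in the hunks above: in Python's standard `logging` module, `Logger.warn` is a deprecated alias of `Logger.warning`. The sketch below uses only the standard library to show the non-deprecated call. The logger name `demo`, the `ListHandler` class, and the `warn_tracker_unsupported` helper are illustrative assumptions and do not appear in the training scripts.

```python
import logging


class ListHandler(logging.Handler):
    """Illustrative helper: collects log messages in memory instead of printing them."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(f"{record.levelname}: {record.getMessage()}")


def warn_tracker_unsupported(logger, tracker_name):
    # Logger.warning is the documented method; Logger.warn is its deprecated alias.
    logger.warning(f"image logging not implemented for {tracker_name}")


handler = ListHandler()
demo_logger = logging.getLogger("demo")  # hypothetical logger name
demo_logger.setLevel(logging.WARNING)
demo_logger.propagate = False  # keep the message out of the root logger
demo_logger.addHandler(handler)

warn_tracker_unsupported(demo_logger, "my_tracker")
print(handler.messages)
```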
+156 -159
@@ -1,7 +1,7 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
from contextlib import nullcontext
import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
@@ -18,56 +18,23 @@ from diffusers import (
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
"--prior_output_path", default="stable-cascade-prior", type=str, help="Hub organization to save the pipelines to"
)
parser.add_argument(
"--decoder_output_path",
type=str,
default="stable-cascade-decoder",
help="Hub organization to save the pipelines to",
)
parser.add_argument(
"--combined_output_path",
type=str,
default="stable-cascade-combined",
help="Hub organization to save the pipelines to",
)
parser.add_argument("--save_combined", action="store_true")
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()
if args.skip_stage_b and args.skip_stage_c:
raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
raise ValueError("Cannot skip stages when creating a combined pipeline")
model_path = args.model_path
device = "cpu"
if args.variant == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
@@ -85,134 +52,164 @@ tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b1
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
# Prior
if args.use_safetensors:
orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
with accelerate.init_empty_weights():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=2048,
block_out_channels=[2048, 2048],
num_attention_heads=[32, 32],
down_num_layers_per_block=[8, 24],
up_num_layers_per_block=[24, 8],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
load_model_dict_into_meta(prior_model, state_dict)
# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
ctx = init_empty_weights if is_accelerate_available() else nullcontext
if not args.skip_stage_c:
# Prior
if args.use_safetensors:
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)
# Decoder
if args.use_safetensors:
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
state_dict[key] = orig_state_dict[key]
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
with ctx():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=2048,
block_out_channels=[2048, 2048],
num_attention_heads=[32, 32],
down_num_layers_per_block=[8, 24],
up_num_layers_per_block=[24, 8],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
if is_accelerate_available():
load_model_dict_into_meta(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.to(dtype).save_pretrained(
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
with accelerate.init_empty_weights():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 640, 1280, 1280],
down_num_layers_per_block=[2, 6, 28, 6],
up_num_layers_per_block=[6, 28, 6, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[3, 3, 2, 2],
num_attention_heads=[0, 0, 20, 20],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
load_model_dict_into_meta(decoder, state_dict)
if not args.skip_stage_b:
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
if args.use_safetensors:
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
with ctx():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 640, 1280, 1280],
down_num_layers_per_block=[2, 6, 28, 6],
up_num_layers_per_block=[6, 28, 6, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[3, 3, 2, 2],
num_attention_heads=[0, 0, 20, 20],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
if is_accelerate_available():
load_model_dict_into_meta(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.to(dtype).save_pretrained(
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if args.save_combined:
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.to(dtype).save_pretrained(
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
-226
@@ -1,226 +0,0 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
from contextlib import nullcontext
import torch
from safetensors.torch import load_file
from transformers import (
AutoTokenizer,
CLIPConfig,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
)
from diffusers import (
DDPMWuerstchenScheduler,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument(
"--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file"
)
parser.add_argument(
"--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file"
)
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
"--prior_output_path",
default="stable-cascade-prior-lite",
type=str,
help="Hub organization to save the pipelines to",
)
parser.add_argument(
"--decoder_output_path",
type=str,
default="stable-cascade-decoder-lite",
help="Hub organization to save the pipelines to",
)
parser.add_argument(
"--combined_output_path",
type=str,
default="stable-cascade-combined-lite",
help="Hub organization to save the pipelines to",
)
parser.add_argument("--save_combined", action="store_true")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()
if args.skip_stage_b and args.skip_stage_c:
raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
raise ValueError("Cannot skip stages when creating a combined pipeline")
model_path = args.model_path
device = "cpu"
if args.variant == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
# Clip Text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
# image processor
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
ctx = init_empty_weights if is_accelerate_available() else nullcontext
if not args.skip_stage_c:
# Prior
if args.use_safetensors:
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
with ctx():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=1536,
block_out_channels=[1536, 1536],
num_attention_heads=[24, 24],
down_num_layers_per_block=[4, 12],
up_num_layers_per_block=[12, 4],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
if is_accelerate_available():
load_model_dict_into_meta(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.to(dtype).save_pretrained(
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if not args.skip_stage_b:
# Decoder
if args.use_safetensors:
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
with ctx():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 576, 1152, 1152],
down_num_layers_per_block=[2, 4, 14, 4],
up_num_layers_per_block=[4, 14, 4, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[2, 2, 2, 2],
num_attention_heads=[0, 9, 18, 18],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
if is_accelerate_available():
load_model_dict_into_meta(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.to(dtype).save_pretrained(
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if args.save_combined:
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.to(dtype).save_pretrained(
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
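Both conversion scripts rely on the same attention-key remapping: a fused `attn.in_proj_weight` / `attn.in_proj_bias` is split into three equal chunks along dimension 0 and renamed to `to_q`, `to_k`, `to_v`, and `attn.out_proj` is renamed to `to_out.0`. The sketch below illustrates this without torch. It is an illustration only: `chunk_rows` and `remap_attention_keys` are hypothetical names, plain lists stand in for tensors, and the decoder-only `clip_mapper` rename is left out.

```python
def chunk_rows(rows, n):
    """Split a list into n equal consecutive parts (stand-in for tensor.chunk(n, 0)).

    Unlike torch's chunk, this sketch rejects lengths that are not divisible by n.
    """
    if len(rows) % n != 0:
        raise ValueError("row count must be divisible by n")
    size = len(rows) // n
    return [rows[i * size:(i + 1) * size] for i in range(n)]


def remap_attention_keys(original_state_dict):
    """Return a new dict with fused attention projections split and renamed."""
    state_dict = {}
    for key, value in original_state_dict.items():
        if key.endswith("in_proj_weight") or key.endswith("in_proj_bias"):
            suffix = "weight" if key.endswith("weight") else "bias"
            q, k, v = chunk_rows(value, 3)
            for name, part in (("to_q", q), ("to_k", k), ("to_v", v)):
                new_key = key.replace(f"attn.in_proj_{suffix}", f"{name}.{suffix}")
                state_dict[new_key] = part
        elif key.endswith("out_proj.weight") or key.endswith("out_proj.bias"):
            state_dict[key.replace("attn.out_proj.", "to_out.0.")] = value
        else:
            # Every other key is copied unchanged.
            state_dict[key] = value
    return state_dict


example = {
    "blk.attn.in_proj_weight": [[1], [2], [3], [4], [5], [6]],
    "blk.attn.out_proj.bias": [0],
    "blk.norm.weight": [9],
}
converted = remap_attention_keys(example)
print(sorted(converted))
```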
+2 -2
@@ -430,7 +430,7 @@ class LoraLoaderMixin:
# contain the module names of the `unet` as its keys WITHOUT any prefix.
if not USE_PEFT_BACKEND:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warning(warn_message)
logger.warn(warn_message)
if len(state_dict.keys()) > 0:
if adapter_name in getattr(unet, "peft_config", {}):
@@ -882,7 +882,7 @@ class LoraLoaderMixin:
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
if self.num_fused_loras > 1:
logger.warning(
logger.warn(
"The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.",
)
+1 -6
@@ -56,8 +56,6 @@ def build_sub_model_components(
if component_name == "unet":
num_in_channels = kwargs.pop("num_in_channels", None)
upcast_attention = kwargs.pop("upcast_attention", None)
unet_components = create_diffusers_unet_model_from_ldm(
pipeline_class_name,
original_config,
@@ -66,7 +64,6 @@ def build_sub_model_components(
image_size=image_size,
torch_dtype=torch_dtype,
model_type=model_type,
upcast_attention=upcast_attention,
)
return unet_components
@@ -303,9 +300,7 @@ class FromSingleFileMixin:
continue
init_kwargs.update(components)
additional_components = set_additional_components(
class_name, original_config, checkpoint=checkpoint, model_type=model_type
)
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
if additional_components:
init_kwargs.update(additional_components)
+12 -116
@@ -81,87 +81,6 @@ SCHEDULER_DEFAULT_CONFIG = {
"timestep_spacing": "leading",
}
STABLE_CASCADE_DEFAULT_CONFIGS = {
"stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"},
"stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"},
"stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"},
"stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"},
}
def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
is_stage_c = "clip_txt_mapper.weight" in original_state_dict
if is_stage_c:
state_dict = {}
for key in original_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = original_state_dict[key]
else:
state_dict = {}
for key in original_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = original_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = original_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = original_state_dict[key]
return state_dict
def infer_stable_cascade_single_file_config(checkpoint):
is_stage_c = "clip_txt_mapper.weight" in checkpoint
is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint
if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
config_type = "stage_c_lite"
elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
config_type = "stage_c"
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
config_type = "stage_b_lite"
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
config_type = "stage_b"
return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
DIFFUSERS_TO_LDM_MAPPING = {
"unet": {
"layers": {
@@ -310,34 +229,10 @@ def fetch_ldm_config_and_checkpoint(
cache_dir=None,
local_files_only=None,
revision=None,
):
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
return original_config, checkpoint
def load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=False,
force_download=False,
proxies=None,
token=None,
cache_dir=None,
local_files_only=None,
revision=None,
):
if os.path.isfile(pretrained_model_link_or_path):
checkpoint = load_state_dict(pretrained_model_link_or_path)
else:
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
checkpoint_path = _get_model_file(
@@ -357,7 +252,9 @@ def load_single_file_model_checkpoint(
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
return checkpoint
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
return original_config, checkpoint
def infer_original_config_file(class_name, checkpoint):
@@ -410,7 +307,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
return original_config
def infer_model_type(original_config, checkpoint, model_type=None):
def infer_model_type(original_config, checkpoint=None, model_type=None):
if model_type is not None:
return model_type
@@ -987,7 +884,7 @@ def create_diffusers_controlnet_model_from_ldm(
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
@@ -1163,7 +1060,7 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
@@ -1258,7 +1155,7 @@ def create_text_encoder_from_open_clip_checkpoint(
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
)
@@ -1279,7 +1176,7 @@ def create_diffusers_unet_model_from_ldm(
original_config,
checkpoint,
num_in_channels=None,
upcast_attention=None,
upcast_attention=False,
extract_ema=False,
image_size=None,
torch_dtype=None,
@@ -1307,8 +1204,7 @@ def create_diffusers_unet_model_from_ldm(
)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["in_channels"] = num_in_channels
if upcast_attention is not None:
unet_config["upcast_attention"] = upcast_attention
unet_config["upcast_attention"] = upcast_attention
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
@@ -1325,7 +1221,7 @@ def create_diffusers_unet_model_from_ldm(
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
@@ -1387,7 +1283,7 @@ def create_diffusers_vae_model_from_ldm(
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
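The unexpected-key handling that recurs in the hunks above can be read as the following minimal sketch. It is not the library's code: the function name `filter_unexpected_keys`, the `ignore_patterns` argument, and the surrounding loop are assumptions made for illustration, since the hunks only show the list comprehension and the logging call. It uses `Logger.warning`, the documented method, of which `Logger.warn` is a deprecated alias in the standard library.

```python
import logging
import re

logger = logging.getLogger(__name__)


def filter_unexpected_keys(unexpected_keys, ignore_patterns):
    """Drop keys that match any ignore pattern, then warn about what is left.

    Hypothetical helper: the name and the `ignore_patterns` argument are
    illustrative and do not come from the diff.
    """
    for pat in ignore_patterns:
        # Same comprehension as in the hunks: keep keys the pattern does NOT match.
        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    if len(unexpected_keys) > 0:
        # `Logger.warning` is the documented spelling; `Logger.warn` is deprecated.
        logger.warning(
            "Some weights of the model checkpoint were not used: %s",
            ", ".join(unexpected_keys),
        )
    return unexpected_keys
```

Returning the filtered list keeps the helper easy to check in isolation; the code in the diff only logs.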
+2 -107
@@ -42,11 +42,6 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
)
from .utils import AttnProcsLayers
@@ -350,7 +345,7 @@ class UNet2DConditionLoadersMixin:
is_model_cpu_offload = False
is_sequential_cpu_offload = False
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
if not USE_PEFT_BACKEND:
if _pipeline is not None:
for _, component in _pipeline.components.items():
@@ -389,7 +384,7 @@ class UNet2DConditionLoadersMixin:
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
if is_text_encoder_present:
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
logger.warning(warn_message)
logger.warn(warn_message)
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
@@ -901,103 +896,3 @@ class UNet2DConditionLoadersMixin:
self.config.encoder_hid_dim_type = "ip_image_proj"
self.to(dtype=self.dtype, device=self.device)
class FromOriginalUNetMixin:
"""
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
`.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing the model weights.
config (`dict`, *optional*):
Dictionary containing the configuration of the model.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables of the model.
"""
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
config = kwargs.pop("config", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if config is None:
config = infer_stable_cascade_single_file_config(checkpoint)
model_config = cls.load_config(**config, **kwargs)
else:
model_config = config
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(model_config, **kwargs)
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
model.load_state_dict(diffusers_format_checkpoint)
if torch_dtype is not None:
model.to(torch_dtype)
return model
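Two steps that `from_single_file` relies on can be sketched in isolation: unwrapping nested `state_dict` entries, and picking a Stage B config from the width of one weight. This is a simplified illustration under stated assumptions, not the library's implementation. The helper names are hypothetical, tensor shapes are modelled as plain tuples, and only the two Stage B branches visible in this diff are covered.

```python
def unwrap_state_dict(checkpoint):
    """Follow nested "state_dict" entries until the actual weights are reached."""
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    return checkpoint


def infer_stage_b_config_type(shapes):
    """Pick a Stage B config name from the last dimension of one weight.

    `shapes` maps parameter names to shape tuples (a stand-in for tensors).
    Hypothetical helper covering only the branches shown in the diff.
    """
    width = shapes["down_blocks.1.0.channelwise.0.weight"][-1]
    if width == 576:
        return "stage_b_lite"
    if width == 640:
        return "stage_b"
    raise ValueError(f"Unrecognized channel width: {width}")
```

Raising on an unknown width is a choice made for this sketch; the diff does not show how the original handles that case.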
+8 -8
@@ -17,7 +17,8 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate
from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleLinear
ACTIVATION_FUNCTIONS = {
@@ -86,7 +87,9 @@ class GEGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
@@ -94,12 +97,9 @@ class GEGLU(nn.Module):
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states, *args, **kwargs):
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
def forward(self, hidden_states, scale: float = 1.0):
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
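For readers unfamiliar with GEGLU, the `forward` shown above amounts to: project to twice the output width, split the result into two halves, and multiply the first half by the GELU of the second. Below is a dependency-free sketch of that arithmetic on plain Python lists. The function names are illustrative, and the projection itself (`self.proj`) is assumed to have been applied already.

```python
import math


def gelu(x: float) -> float:
    """Exact (erf-based) GELU."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def geglu(projected):
    """Split `projected` into (hidden, gate) halves and return hidden * gelu(gate).

    `projected` stands in for the output of the doubled projection, so its
    length must be even.
    """
    if len(projected) % 2 != 0:
        raise ValueError("projected length must be even")
    half = len(projected) // 2
    hidden, gate = projected[:half], projected[half:]
    return [h * gelu(g) for h, g in zip(hidden, gate)]
```

A gate value of zero suppresses the corresponding hidden value entirely, while a large positive gate passes it through scaled by roughly the gate itself.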
+33 -23
@@ -17,18 +17,18 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, logging
from ..utils import USE_PEFT_BACKEND
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
logger = logging.get_logger(__name__)
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
def _chunked_feed_forward(
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
@@ -36,10 +36,18 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
if lora_scale is None:
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
else:
# TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
ff_output = torch.cat(
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
@@ -291,10 +299,6 @@ class BasicTransformerBlock(nn.Module):
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
@@ -322,7 +326,10 @@ class BasicTransformerBlock(nn.Module):
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Prepare GLIGEN inputs
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
@@ -341,7 +348,7 @@ class BasicTransformerBlock(nn.Module):
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1.2 GLIGEN Control
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
@@ -387,9 +394,11 @@ class BasicTransformerBlock(nn.Module):
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
ff_output = _chunked_feed_forward(
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
)
else:
ff_output = self.ff(norm_hidden_states)
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.norm_type == "ada_norm_zero":
ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -634,7 +643,7 @@ class FeedForward(nn.Module):
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = nn.Linear
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
@@ -656,10 +665,11 @@ class FeedForward(nn.Module):
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
for module in self.net:
hidden_states = module(hidden_states)
if isinstance(module, compatible_cls):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
return hidden_states
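The idea behind `_chunked_feed_forward` is that a feed-forward layer acts on each position independently, so the input can be processed in slices and concatenated without changing the result, at a lower peak memory cost. A minimal sketch on a flat Python list follows. The function name and the list-based stand-in for tensors are assumptions made for illustration.

```python
def chunked_feed_forward(ff, hidden_states, chunk_size):
    """Apply `ff` to consecutive slices of `hidden_states` and concatenate.

    `ff` takes a list and returns a list of the same length. As in the diff,
    the input length must be divisible by `chunk_size`.
    """
    if len(hidden_states) % chunk_size != 0:
        raise ValueError(
            f"Length {len(hidden_states)} is not divisible by chunk size {chunk_size}."
        )
    num_chunks = len(hidden_states) // chunk_size
    output = []
    for i in range(num_chunks):
        chunk = hidden_states[i * chunk_size : (i + 1) * chunk_size]
        output.extend(ff(chunk))
    return output
```

The equivalence with the unchunked call only holds when `ff` treats positions independently, which is the case for a position-wise feed-forward block.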
+56 -73
@@ -20,10 +20,10 @@ import torch.nn.functional as F
from torch import nn
from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRALinearLayer
from .lora import LoRACompatibleLinear, LoRALinearLayer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -181,7 +181,10 @@ class Attention(nn.Module):
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
linear_cls = nn.Linear
if USE_PEFT_BACKEND:
linear_cls = nn.Linear
else:
linear_cls = LoRACompatibleLinear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
@@ -738,15 +741,12 @@ class AttnProcessor:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -764,15 +764,15 @@ class AttnProcessor:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
@@ -783,7 +783,7 @@ class AttnProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -914,15 +914,12 @@ class AttnAddedKVProcessor:
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -935,17 +932,17 @@ class AttnAddedKVProcessor:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.to_k(hidden_states, *args)
value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -959,7 +956,7 @@ class AttnAddedKVProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -987,15 +984,12 @@ class AttnAddedKVProcessor2_0:
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -1008,7 +1002,7 @@ class AttnAddedKVProcessor2_0:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query, out_dim=4)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -1017,8 +1011,8 @@ class AttnAddedKVProcessor2_0:
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.to_k(hidden_states, *args)
value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -1035,7 +1029,7 @@ class AttnAddedKVProcessor2_0:
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# linear proj
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1138,15 +1132,12 @@ class XFormersAttnProcessor:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1174,15 +1165,15 @@ class XFormersAttnProcessor:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
@@ -1195,7 +1186,7 @@ class XFormersAttnProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1226,13 +1217,8 @@ class AttnProcessor2_0:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1256,15 +1242,16 @@ class AttnProcessor2_0:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -1284,7 +1271,7 @@ class AttnProcessor2_0:
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1325,13 +1312,8 @@ class FusedAttnProcessor2_0:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1355,16 +1337,17 @@ class FusedAttnProcessor2_0:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states)
qkv = attn.to_qkv(hidden_states, *args)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states)
query = attn.to_q(hidden_states, *args)
kv = attn.to_kv(encoder_hidden_states)
kv = attn.to_kv(encoder_hidden_states, *args)
split_size = kv.shape[-1] // 2
key, value = torch.split(kv, split_size, dim=-1)
@@ -1385,7 +1368,7 @@ class FusedAttnProcessor2_0:
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1876,7 +1859,7 @@ class LoRAAttnProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -1894,7 +1877,7 @@ class LoRAAttnProcessor(nn.Module):
attn._modules.pop("processor")
attn.processor = AttnProcessor()
return attn.processor(attn, hidden_states, **kwargs)
return attn.processor(attn, hidden_states, *args, **kwargs)
class LoRAAttnProcessor2_0(nn.Module):
@@ -1937,7 +1920,7 @@ class LoRAAttnProcessor2_0(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -1955,7 +1938,7 @@ class LoRAAttnProcessor2_0(nn.Module):
attn._modules.pop("processor")
attn.processor = AttnProcessor2_0()
return attn.processor(attn, hidden_states, **kwargs)
return attn.processor(attn, hidden_states, *args, **kwargs)
class LoRAXFormersAttnProcessor(nn.Module):
@@ -2016,7 +1999,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -2034,7 +2017,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
attn._modules.pop("processor")
attn.processor = XFormersAttnProcessor()
return attn.processor(attn, hidden_states, **kwargs)
return attn.processor(attn, hidden_states, *args, **kwargs)
class LoRAAttnAddedKVProcessor(nn.Module):
@@ -2075,7 +2058,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -2093,7 +2076,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
attn._modules.pop("processor")
attn.processor = AttnAddedKVProcessor()
return attn.processor(attn, hidden_states, **kwargs)
return attn.processor(attn, hidden_states, *args, **kwargs)
class IPAdapterAttnProcessor(nn.Module):
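The `args = () if USE_PEFT_BACKEND else (scale,)` idiom used throughout the processors above forwards `scale` only to layers that accept it. The sketch below reproduces the pattern with two toy scalar "layers". The classes `PlainLinear` and `ScaleAwareLinear` and the helper `project` are hypothetical stand-ins for `nn.Linear` and `LoRACompatibleLinear`, not code from the repository.

```python
class PlainLinear:
    """Stand-in for a layer whose call signature has no `scale` argument."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return self.weight * x


class ScaleAwareLinear(PlainLinear):
    """Stand-in for a LoRA-compatible layer that accepts an optional scale."""

    def __init__(self, weight, lora_delta):
        super().__init__(weight)
        self.lora_delta = lora_delta

    def __call__(self, x, scale=1.0):
        return self.weight * x + scale * self.lora_delta * x


def project(layer, x, scale=1.0, use_peft_backend=False):
    # Forward `scale` positionally only when the legacy (non-PEFT) path is active.
    args = () if use_peft_backend else (scale,)
    return layer(x, *args)
```

The flag and the layer class have to agree: calling `project(PlainLinear(2.0), 4.0)` with `use_peft_backend=False` raises a `TypeError`, which is why the same flag selects the layer class in the constructors of this diff.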
+11 -7
@@ -18,7 +18,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import deprecate
from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleConv
from .normalization import RMSNorm
from .upsampling import upfirdn2d_native
@@ -102,7 +103,7 @@ class Downsample2D(nn.Module):
self.padding = padding
stride = 2
self.name = name
conv_cls = nn.Conv2d
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -130,10 +131,7 @@ class Downsample2D(nn.Module):
else:
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
@@ -145,7 +143,13 @@ class Downsample2D(nn.Module):
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
if not USE_PEFT_BACKEND:
if isinstance(self.conv, LoRACompatibleConv):
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.conv(hidden_states)
return hidden_states
+3 -2
@@ -18,9 +18,10 @@ import numpy as np
import torch
from torch import nn
from ..utils import deprecate
from ..utils import USE_PEFT_BACKEND, deprecate
from .activations import get_activation
from .attention_processor import Attention
from .lora import LoRACompatibleLinear
def get_timestep_embedding(
@@ -199,7 +200,7 @@ class TimestepEmbedding(nn.Module):
sample_proj_bias=True,
):
super().__init__()
linear_cls = nn.Linear
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
-6
@@ -204,9 +204,6 @@ class LoRALinearLayer(nn.Module):
):
super().__init__()
deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRALinearLayer", "1.0.0", deprecation_message)
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
@@ -267,9 +264,6 @@ class LoRAConv2dLayer(nn.Module):
):
super().__init__()
deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message)
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
+2 -2
@@ -677,7 +677,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
@@ -705,7 +705,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# the weights so we don't have to do this again.
if "'Attention' object has no attribute" in str(e):
logger.warning(
logger.warn(
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+56 -29
@@ -20,7 +20,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import deprecate
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .attention_processor import SpatialNorm
from .downsampling import ( # noqa
@@ -30,6 +30,7 @@ from .downsampling import ( # noqa
KDownsample2D,
downsample_2d,
)
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .normalization import AdaGroupNorm
from .upsampling import ( # noqa
FirUpsample2D,
@@ -101,7 +102,7 @@ class ResnetBlockCondNorm2D(nn.Module):
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
conv_cls = nn.Conv2d
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if groups_out is None:
groups_out = groups
@@ -148,11 +149,12 @@ class ResnetBlockCondNorm2D(nn.Module):
bias=conv_shortcut_bias,
)
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states, temb)
@@ -164,24 +166,26 @@ class ResnetBlockCondNorm2D(nn.Module):
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = self.upsample(input_tensor)
hidden_states = self.upsample(hidden_states)
input_tensor = self.upsample(input_tensor, scale=scale)
hidden_states = self.upsample(hidden_states, scale=scale)
elif self.downsample is not None:
input_tensor = self.downsample(input_tensor)
hidden_states = self.downsample(hidden_states)
input_tensor = self.downsample(input_tensor, scale=scale)
hidden_states = self.downsample(hidden_states, scale=scale)
hidden_states = self.conv1(hidden_states)
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states, temb)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
input_tensor = (
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
@@ -263,8 +267,8 @@ class ResnetBlock2D(nn.Module):
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
linear_cls = nn.Linear
conv_cls = nn.Conv2d
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if groups_out is None:
groups_out = groups
@@ -322,11 +326,12 @@ class ResnetBlock2D(nn.Module):
bias=conv_shortcut_bias,
)
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
def forward(
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
@@ -337,18 +342,38 @@ class ResnetBlock2D(nn.Module):
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = self.upsample(input_tensor)
hidden_states = self.upsample(hidden_states)
input_tensor = (
self.upsample(input_tensor, scale=scale)
if isinstance(self.upsample, Upsample2D)
else self.upsample(input_tensor)
)
hidden_states = (
self.upsample(hidden_states, scale=scale)
if isinstance(self.upsample, Upsample2D)
else self.upsample(hidden_states)
)
elif self.downsample is not None:
input_tensor = self.downsample(input_tensor)
hidden_states = self.downsample(hidden_states)
input_tensor = (
self.downsample(input_tensor, scale=scale)
if isinstance(self.downsample, Downsample2D)
else self.downsample(input_tensor)
)
hidden_states = (
self.downsample(hidden_states, scale=scale)
if isinstance(self.downsample, Downsample2D)
else self.downsample(hidden_states)
)
hidden_states = self.conv1(hidden_states)
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
if self.time_emb_proj is not None:
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, None, None]
temb = (
self.time_emb_proj(temb, scale)[:, :, None, None]
if not USE_PEFT_BACKEND
else self.time_emb_proj(temb)[:, :, None, None]
)
if self.time_embedding_norm == "default":
if temb is not None:
@@ -368,10 +393,12 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
input_tensor = (
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
@@ -19,16 +19,14 @@ import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, deprecate, is_torch_version, logging
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from ..attention import BasicTransformerBlock
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
from ..lora import LoRACompatibleConv, LoRACompatibleLinear
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
@@ -117,8 +115,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
conv_cls = nn.Conv2d
linear_cls = nn.Linear
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
@@ -306,9 +304,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -332,6 +327,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
@@ -339,13 +337,21 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
@@ -408,9 +414,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
+98 -163
@@ -18,7 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ...utils import deprecate, is_torch_version, logging
from ...utils import is_torch_version, logging
from ...utils.torch_utils import apply_freeu
from ..activations import get_activation
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -69,7 +69,7 @@ def get_down_block(
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
@@ -354,7 +354,7 @@ def get_up_block(
) -> nn.Module:
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
@@ -673,7 +673,7 @@ class UNetMidBlock2D(nn.Module):
attentions = []
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
@@ -844,11 +844,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.training and self.gradient_checkpointing:
@@ -885,7 +882,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
@@ -985,8 +982,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -999,7 +995,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
hidden_states = attn(
@@ -1010,7 +1006,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
)
# resnet
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
@@ -1039,7 +1035,7 @@ class AttnDownBlock2D(nn.Module):
self.downsample_type = downsample_type
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -1115,22 +1111,23 @@ class AttnDownBlock2D(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0)
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
cross_attention_kwargs.update({"scale": lora_scale})
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
if self.downsample_type == "resnet":
hidden_states = downsampler(hidden_states, temb=temb)
hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale)
else:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=lora_scale)
output_states += (hidden_states,)
@@ -1239,12 +1236,10 @@ class CrossAttnDownBlock2D(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
@@ -1275,7 +1270,7 @@ class CrossAttnDownBlock2D(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1293,7 +1288,7 @@ class CrossAttnDownBlock2D(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=lora_scale)
output_states = output_states + (hidden_states,)
@@ -1353,12 +1348,8 @@ class DownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet in self.resnets:
@@ -1379,13 +1370,13 @@ class DownBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=scale)
output_states = output_states + (hidden_states,)
@@ -1456,17 +1447,13 @@ class DownEncoderBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)
hidden_states = resnet(hidden_states, temb=None, scale=scale)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale)
return hidden_states
@@ -1493,7 +1480,7 @@ class AttnDownEncoderBlock2D(nn.Module):
attentions = []
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -1558,18 +1545,15 @@ class AttnDownEncoderBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=None)
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb=None, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, **cross_attention_kwargs)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale)
return hidden_states
@@ -1595,7 +1579,7 @@ class AttnSkipDownBlock2D(nn.Module):
self.resnets = nn.ModuleList([])
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -1660,22 +1644,18 @@ class AttnSkipDownBlock2D(nn.Module):
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
skip_sample: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
hidden_states = self.resnet_down(hidden_states, temb, scale=scale)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
@@ -1751,21 +1731,16 @@ class SkipDownBlock2D(nn.Module):
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
skip_sample: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
hidden_states = self.resnet_down(hidden_states, temb, scale)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
@@ -1841,12 +1816,8 @@ class ResnetDownsampleBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet in self.resnets:
@@ -1867,13 +1838,13 @@ class ResnetDownsampleBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb)
hidden_states = downsampler(hidden_states, temb, scale)
output_states = output_states + (hidden_states,)
@@ -1984,11 +1955,10 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
output_states = ()
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -2021,7 +1991,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
**cross_attention_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
@@ -2034,7 +2004,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb)
hidden_states = downsampler(hidden_states, temb, scale=lora_scale)
output_states = output_states + (hidden_states,)
@@ -2088,12 +2058,8 @@ class KDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
for resnet in self.resnets:
@@ -2114,7 +2080,7 @@ class KDownBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale)
output_states += (hidden_states,)
@@ -2199,11 +2165,8 @@ class KCrossAttnDownBlock2D(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
@@ -2233,7 +2196,7 @@ class KCrossAttnDownBlock2D(nn.Module):
encoder_attention_mask=encoder_attention_mask,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2281,7 +2244,7 @@ class AttnUpBlock2D(nn.Module):
self.upsample_type = upsample_type
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -2353,28 +2316,24 @@ class AttnUpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, **cross_attention_kwargs)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
if self.upsample_type == "resnet":
hidden_states = upsampler(hidden_states, temb=temb)
hidden_states = upsampler(hidden_states, temb=temb, scale=scale)
else:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, scale=scale)
return hidden_states
@@ -2481,10 +2440,7 @@ class CrossAttnUpBlock2D(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -2538,7 +2494,7 @@ class CrossAttnUpBlock2D(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2550,7 +2506,7 @@ class CrossAttnUpBlock2D(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
return hidden_states
@@ -2611,13 +2567,8 @@ class UpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -2661,11 +2612,11 @@ class UpBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
return hidden_states
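Note on the `*args, **kwargs` variant of `forward` in the hunk above: the extra parameters exist only to detect a stray `scale` argument, passed positionally or by keyword, and to emit a deprecation notice before ignoring it. The sketch below reproduces that guard with the standard-library `warnings` module as a stand-in for the library's `deprecate` helper; `forward_with_deprecated_scale` and `count_warnings` are hypothetical names used for illustration only.

```python
import warnings


def forward_with_deprecated_scale(hidden_states, temb=None, *args, **kwargs):
    """Hypothetical stand-in for a block's forward(); `scale` is accepted but ignored."""
    # Same condition as in the diff: any extra positional argument, or a
    # non-None `scale` keyword, triggers the deprecation notice.
    if len(args) > 0 or kwargs.get("scale", None) is not None:
        warnings.warn(
            "The `scale` argument is deprecated and will be ignored.",
            FutureWarning,
        )
    return hidden_states


def count_warnings(*args, **kwargs):
    """Call the forward above and return how many warnings it emitted."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        forward_with_deprecated_scale(*args, **kwargs)
    return len(caught)
```

Calling `count_warnings("x")` exercises the silent path, while `count_warnings("x", None, scale=0.5)` and `count_warnings("x", None, 0.5)` exercise the keyword and positional paths.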
@@ -2732,9 +2683,11 @@ class UpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -2766,7 +2719,7 @@ class AttnUpDecoderBlock2D(nn.Module):
attentions = []
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
)
attention_head_dim = out_channels
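Note on the two logging calls in the hunk above: in Python's standard `logging` module, `Logger.warn` is a deprecated alias of `Logger.warning` (deprecated since Python 3.3), so the two spellings differ in whether they rely on that alias. The sketch below shows the fallback logic using `warning` only. It assumes a plain standard-library logger rather than the library's own logging wrapper, and `resolve_attention_head_dim` and `ListHandler` are hypothetical names.

```python
import logging


class ListHandler(logging.Handler):
    """Collects (level name, message) pairs so the output can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append((record.levelname, record.getMessage()))


logger = logging.getLogger("attention_head_dim_demo")
logger.setLevel(logging.WARNING)
logger.propagate = False  # keep the demo output out of the root logger
handler = ListHandler()
logger.addHandler(handler)


def resolve_attention_head_dim(attention_head_dim, out_channels):
    """Hypothetical helper: default the head dim to out_channels, with a warning."""
    if attention_head_dim is None:
        logger.warning(
            "Passing `attention_head_dim=None` is not recommended. "
            f"Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
        )
        attention_head_dim = out_channels
    return attention_head_dim
```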
@@ -2830,14 +2783,17 @@ class AttnUpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> torch.FloatTensor:
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb)
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, scale=scale)
return hidden_states
@@ -2885,7 +2841,7 @@ class AttnSkipUpBlock2D(nn.Module):
)
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
)
attention_head_dim = out_channels
@@ -2942,22 +2898,18 @@ class AttnSkipUpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
skip_sample=None,
*args,
**kwargs,
scale: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = self.attentions[0](hidden_states)
cross_attention_kwargs = {"scale": scale}
hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
@@ -2971,7 +2923,7 @@ class AttnSkipUpBlock2D(nn.Module):
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
return hidden_states, skip_sample
@@ -3054,20 +3006,15 @@ class SkipUpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
skip_sample=None,
*args,
**kwargs,
scale: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
@@ -3081,7 +3028,7 @@ class SkipUpBlock2D(nn.Module):
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
return hidden_states, skip_sample
@@ -3161,13 +3108,8 @@ class ResnetUpsampleBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -3191,11 +3133,11 @@ class ResnetUpsampleBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, temb)
hidden_states = upsampler(hidden_states, temb, scale=scale)
return hidden_states
@@ -3311,9 +3253,8 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
mask = None if encoder_hidden_states is None else encoder_attention_mask
@@ -3351,7 +3292,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
**cross_attention_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
@@ -3362,7 +3303,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, temb)
hidden_states = upsampler(hidden_states, temb, scale=lora_scale)
return hidden_states
@@ -3423,13 +3364,8 @@ class KUpBlock2D(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
@@ -3452,7 +3388,7 @@ class KUpBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -3562,6 +3498,7 @@ class KCrossAttnUpBlock2D(nn.Module):
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
@@ -3590,7 +3527,7 @@ class KCrossAttnUpBlock2D(nn.Module):
encoder_attention_mask=encoder_attention_mask,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -3693,8 +3630,6 @@ class KAttentionBlock(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
# 1. Self-Attention
if self.add_self_attention:
@@ -1226,7 +1226,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
if is_adapter and len(down_intrablock_additional_residuals) > 0:
sample += down_intrablock_additional_residuals.pop(0)
@@ -1297,6 +1297,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
)
# 6. post-process
+18 -39
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from ...utils import deprecate, is_torch_version, logging
from ...utils import is_torch_version
from ...utils.torch_utils import apply_freeu
from ..attention import Attention
from ..resnet import (
@@ -35,9 +35,6 @@ from ..transformers.transformer_temporal import (
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_down_block(
down_block_type: str,
num_layers: int,
@@ -1008,14 +1005,9 @@ class DownBlockMotion(nn.Module):
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
num_frames: int = 1,
*args,
**kwargs,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
output_states = ()
blocks = zip(self.resnets, self.motion_modules)
@@ -1037,18 +1029,18 @@ class DownBlockMotion(nn.Module):
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
create_custom_forward(resnet), hidden_states, temb, scale
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=scale)
output_states = output_states + (hidden_states,)
@@ -1181,12 +1173,10 @@ class CrossAttnDownBlockMotion(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
):
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
for i, (resnet, attn, motion_module) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
@@ -1216,7 +1206,7 @@ class CrossAttnDownBlockMotion(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1238,7 +1228,7 @@ class CrossAttnDownBlockMotion(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=lora_scale)
output_states = output_states + (hidden_states,)
@@ -1365,10 +1355,7 @@ class CrossAttnUpBlockMotion(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -1423,7 +1410,7 @@ class CrossAttnUpBlockMotion(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1439,7 +1426,7 @@ class CrossAttnUpBlockMotion(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
return hidden_states
@@ -1520,14 +1507,9 @@ class UpBlockMotion(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size=None,
scale: float = 1.0,
num_frames: int = 1,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -1577,12 +1559,12 @@ class UpBlockMotion(nn.Module):
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
return hidden_states
@@ -1705,11 +1687,8 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
@@ -1758,7 +1737,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
hidden_states,
num_frames=num_frames,
)[0]
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
+1 -1
@@ -89,7 +89,7 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
ff_output = self.ff(hidden_states)
ff_output = self.ff(hidden_states, scale=1.0)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
@@ -21,7 +21,6 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.unet import FromOriginalUNetMixin
from ...utils import BaseOutput
from ..attention_processor import Attention
from ..modeling_utils import ModelMixin
@@ -135,7 +134,7 @@ class StableCascadeUNetOutput(BaseOutput):
sample: torch.FloatTensor = None
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
class StableCascadeUNet(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
+15 -9
@@ -18,7 +18,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import deprecate
from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleConv
from .normalization import RMSNorm
@@ -110,7 +111,7 @@ class Upsample2D(nn.Module):
self.use_conv_transpose = use_conv_transpose
self.name = name
self.interpolate = interpolate
conv_cls = nn.Conv2d
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -140,12 +141,11 @@ class Upsample2D(nn.Module):
self.Conv2d_0 = conv
def forward(
self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
@@ -180,9 +180,15 @@ class Upsample2D(nn.Module):
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
hidden_states = self.conv(hidden_states)
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.Conv2d_0(hidden_states)
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.Conv2d_0(hidden_states, scale)
else:
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
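Note on the `isinstance(..., LoRACompatibleConv) and not USE_PEFT_BACKEND` branches in the hunk above: `scale` is passed to the convolution only when the layer is the LoRA-aware class and the PEFT backend is not in use; otherwise the layer is called with the input alone. The sketch below illustrates that dispatch with toy stand-in classes instead of real `torch` modules. `PlainConv`, `ScaledConv`, `apply_conv`, and the hand-set `USE_PEFT_BACKEND` flag are all assumptions made for illustration, and the arithmetic is a placeholder rather than an actual convolution.

```python
USE_PEFT_BACKEND = False  # assumption: set by hand here, not read from the library


class PlainConv:
    """Toy stand-in for a regular layer that takes only the input."""

    def __call__(self, x):
        return x


class ScaledConv(PlainConv):
    """Toy stand-in for a LoRA-aware layer that also accepts a scale."""

    lora_delta = 1.0  # placeholder for the low-rank update's contribution

    def __call__(self, x, scale=1.0):
        # Base output plus a scaled extra term, mimicking how a LoRA scale
        # weights the adapter's contribution.
        return x + scale * self.lora_delta


def apply_conv(conv, x, scale=1.0):
    # Same dispatch shape as in the diff: forward `scale` only to the
    # LoRA-aware class, and only when the PEFT backend is not in use.
    if isinstance(conv, ScaledConv) and not USE_PEFT_BACKEND:
        return conv(x, scale)
    return conv(x)
```

With the flag left at `False`, `apply_conv(ScaledConv(), 2.0, scale=0.5)` takes the scaled path and `apply_conv(PlainConv(), 2.0, scale=0.5)` ignores `scale`.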
@@ -156,7 +156,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
self.dtype = dtype
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -416,13 +416,13 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -460,13 +460,13 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -175,7 +175,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
)
if unet.config.in_channels != 6:
logger.warning(
logger.warn(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
)
@@ -209,13 +209,13 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -500,13 +500,13 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -177,7 +177,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
)
if unet.config.in_channels != 6:
logger.warning(
logger.warn(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
)
@@ -211,13 +211,13 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -133,7 +133,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
)
if unet.config.in_channels != 6:
logger.warning(
logger.warn(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
)
@@ -167,13 +167,13 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -1333,7 +1333,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
if is_adapter and len(down_intrablock_additional_residuals) > 0:
sample += down_intrablock_additional_residuals.pop(0)
@@ -1589,7 +1589,7 @@ class DownBlockFlat(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
@@ -1611,13 +1611,13 @@ class DownBlockFlat(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=scale)
output_states = output_states + (hidden_states,)
@@ -1728,6 +1728,8 @@ class CrossAttnDownBlockFlat(nn.Module):
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
@@ -1758,7 +1760,7 @@ class CrossAttnDownBlockFlat(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1776,7 +1778,7 @@ class CrossAttnDownBlockFlat(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=lora_scale)
output_states = output_states + (hidden_states,)
@@ -1840,13 +1842,8 @@ class UpBlockFlat(nn.Module):
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
scale: float = 1.0,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -1890,11 +1887,11 @@ class UpBlockFlat(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
return hidden_states
@@ -2002,10 +1999,7 @@ class CrossAttnUpBlockFlat(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -2059,7 +2053,7 @@ class CrossAttnUpBlockFlat(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2071,7 +2065,7 @@ class CrossAttnUpBlockFlat(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
return hidden_states
@@ -2164,7 +2158,7 @@ class UNetMidBlockFlat(nn.Module):
attentions = []
if attention_head_dim is None:
logger.warning(
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
@@ -2336,11 +2330,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.training and self.gradient_checkpointing:
@@ -2377,7 +2368,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
@@ -2478,8 +2469,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -2492,7 +2482,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
hidden_states = attn(
@@ -2503,6 +2493,6 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
)
# resnet
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
@@ -481,7 +481,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse(
"0.23.0.dev0"
):
logger.warning(
logger.warn(
"Please note that the expected format of `mask_image` has recently been changed. "
"Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. "
"As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. "
@@ -372,7 +372,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse(
"0.23.0.dev0"
):
logger.warning(
logger.warn(
"Please note that the expected format of `mask_image` has recently been changed. "
"Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. "
"As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. "
+4 -6
@@ -256,9 +256,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
break
if save_method_name is None:
logger.warning(
f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved."
)
logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
# make sure that unsaveable components are not tried to be loaded afterward
self.register_to_config(**{pipeline_component_name: (None, None)})
continue
@@ -1204,7 +1202,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
try:
info = model_info(pretrained_model_name, token=token, revision=revision)
except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
@@ -1355,7 +1353,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
):
logger.warning(
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
@@ -1368,7 +1366,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
@@ -514,13 +514,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warn("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
@@ -19,7 +19,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_torch_version, replace_example_docstring
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
@@ -29,10 +29,11 @@ from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableCascadeCombinedPipeline
>>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16)
>>> pipe.enable_model_cpu_offload()
>>> from diffusions import StableCascadeCombinedPipeline
>>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16).to(
... "cuda"
... )
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> images = pipe(prompt=prompt)
```
@@ -259,12 +260,6 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
dtype = self.decoder_pipe.decoder.dtype
if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
raise ValueError(
"`StableCascadeCombinedPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype."
)
prior_outputs = self.prior_pipe(
prompt=prompt if prompt_embeds is None else None,
images=images,
@@ -147,7 +147,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
self.dtype = dtype
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -82,7 +82,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMi
super().__init__()
if safety_checker is None and requires_safety_checker:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -2,6 +2,8 @@ import torch
import torch.nn as nn
from ...models.attention_processor import Attention
from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
from ...utils import USE_PEFT_BACKEND
class WuerstchenLayerNorm(nn.LayerNorm):
@@ -17,7 +19,7 @@ class WuerstchenLayerNorm(nn.LayerNorm):
class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep):
super().__init__()
linear_cls = nn.Linear
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.mapper = linear_cls(c_timestep, c * 2)
def forward(self, x, t):
@@ -29,8 +31,8 @@ class ResBlock(nn.Module):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
super().__init__()
conv_cls = nn.Conv2d
linear_cls = nn.Linear
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
@@ -64,7 +66,7 @@ class AttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
super().__init__()
linear_cls = nn.Linear
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.self_attn = self_attn
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
@@ -28,8 +28,9 @@ from ...models.attention_processor import (
AttnAddedKVProcessor,
AttnProcessor,
)
from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
from ...models.modeling_utils import ModelMixin
from ...utils import is_torch_version
from ...utils import USE_PEFT_BACKEND, is_torch_version
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
@@ -40,8 +41,8 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
@register_to_config
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
super().__init__()
conv_cls = nn.Conv2d
linear_cls = nn.Linear
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.c_r = c_r
self.projection = conv_cls(c_in, c, kernel_size=1)
@@ -319,13 +319,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
logger.warning(
logger.warn(
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=False`."
)
self.register_to_config(lower_order_final=True)
if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
logger.warning(
logger.warn(
" `last_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
)
self.register_to_config(lower_order_final=True)
+1 -1
@@ -116,7 +116,7 @@ def export_to_obj(mesh, output_obj_path: str = None):
def export_to_video(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
if is_opencv_available():
import cv2
+3 -2
@@ -22,6 +22,7 @@ from torch import nn
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformers.transformer_2d import Transformer2DModel
from diffusers.utils.testing_utils import (
@@ -481,7 +482,7 @@ class Transformer2DModelTests(unittest.TestCase):
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
dim = 32
inner_dim = 128
@@ -505,7 +506,7 @@ class Transformer2DModelTests(unittest.TestCase):
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
dim = 32
inner_dim = 128
@@ -1,195 +0,0 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from diffusers import StableCascadeUNet
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
)
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__)
enable_full_determinism()
@slow
class StableCascadeUNetModelSlowTests(unittest.TestCase):
def tearDown(self) -> None:
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_cascade_unet_prior_single_file_components(self):
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
single_file_unet = StableCascadeUNet.from_single_file(single_file_url)
single_file_unet_config = single_file_unet.config
del single_file_unet
gc.collect()
torch.cuda.empty_cache()
unet = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade-prior", subfolder="prior", revision="refs/pr/2", variant="bf16"
)
unet_config = unet.config
del unet
gc.collect()
torch.cuda.empty_cache()
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
for param_name, param_value in single_file_unet_config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert unet_config[param_name] == param_value
def test_stable_cascade_unet_decoder_single_file_components(self):
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors"
single_file_unet = StableCascadeUNet.from_single_file(single_file_url)
single_file_unet_config = single_file_unet.config
del single_file_unet
gc.collect()
torch.cuda.empty_cache()
unet = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade", subfolder="decoder", revision="refs/pr/44", variant="bf16"
)
unet_config = unet.config
del unet
gc.collect()
torch.cuda.empty_cache()
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
for param_name, param_value in single_file_unet_config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert unet_config[param_name] == param_value
def test_stable_cascade_unet_config_loading(self):
config = StableCascadeUNet.load_config(
pretrained_model_name_or_path="diffusers/stable-cascade-configs", subfolder="prior"
)
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
single_file_unet = StableCascadeUNet.from_single_file(single_file_url, config=config)
single_file_unet_config = single_file_unet.config
del single_file_unet
gc.collect()
torch.cuda.empty_cache()
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
for param_name, param_value in config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert single_file_unet_config[param_name] == param_value
@require_torch_gpu
def test_stable_cascade_unet_single_file_prior_forward_pass(self):
dtype = torch.bfloat16
generator = torch.Generator("cpu")
model_inputs = {
"sample": randn_tensor((1, 16, 24, 24), generator=generator.manual_seed(0)).to("cuda", dtype),
"timestep_ratio": torch.tensor([1]).to("cuda", dtype),
"clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
"clip_text": randn_tensor((1, 77, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
"clip_img": randn_tensor((1, 1, 768), generator=generator.manual_seed(0)).to("cuda", dtype),
"pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype),
}
unet = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade-prior",
subfolder="prior",
revision="refs/pr/2",
variant="bf16",
torch_dtype=dtype,
)
unet.to("cuda")
with torch.no_grad():
prior_output = unet(**model_inputs).sample.float().cpu().numpy()
# Remove UNet from GPU memory before loading the single file UNet model
del unet
gc.collect()
torch.cuda.empty_cache()
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype)
single_file_unet.to("cuda")
with torch.no_grad():
prior_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy()
# Remove UNet from GPU memory before loading the single file UNet model
del single_file_unet
gc.collect()
torch.cuda.empty_cache()
max_diff = numpy_cosine_similarity_distance(prior_output.flatten(), prior_single_file_output.flatten())
assert max_diff < 8e-3
@require_torch_gpu
def test_stable_cascade_unet_single_file_decoder_forward_pass(self):
dtype = torch.float32
generator = torch.Generator("cpu")
model_inputs = {
"sample": randn_tensor((1, 4, 256, 256), generator=generator.manual_seed(0)).to("cuda", dtype),
"timestep_ratio": torch.tensor([1]).to("cuda", dtype),
"clip_text": randn_tensor((1, 77, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
"clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
"pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype),
}
unet = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade",
subfolder="decoder",
revision="refs/pr/44",
torch_dtype=dtype,
)
unet.to("cuda")
with torch.no_grad():
prior_output = unet(**model_inputs).sample.float().cpu().numpy()
# Remove UNet from GPU memory before loading the single file UNet model
del unet
gc.collect()
torch.cuda.empty_cache()
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b.safetensors"
single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype)
single_file_unet.to("cuda")
with torch.no_grad():
prior_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy()
# Remove UNet from GPU memory before loading the single file UNet model
del single_file_unet
gc.collect()
torch.cuda.empty_cache()
max_diff = numpy_cosine_similarity_distance(prior_output.flatten(), prior_single_file_output.flatten())
assert max_diff < 1e-4
@@ -838,11 +838,9 @@ class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
pipe.unet.config[param_name] = False
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} is differs between single file loading and pretrained loading"
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE: