Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4b682be902 | |||
| 3d0bb51d53 | |||
| 4b72aae0cd | |||
| 33bbe58ea7 |
@@ -12,13 +12,13 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Stable Cascade
|
||||
|
||||
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
|
||||
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
|
||||
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
|
||||
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
|
||||
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
|
||||
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
|
||||
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
|
||||
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
|
||||
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
|
||||
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
|
||||
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
|
||||
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
|
||||
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
|
||||
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
|
||||
Diffusion 1.5.
|
||||
|
||||
Therefore, this kind of model is well suited for usages where efficiency is important. Furthermore, all known extensions
|
||||
@@ -30,154 +30,13 @@ The original codebase can be found at [Stability-AI/StableCascade](https://githu
|
||||
Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade to generate images,
|
||||
hence the name "Stable Cascade".
|
||||
|
||||
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
|
||||
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
|
||||
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
|
||||
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
|
||||
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
|
||||
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
|
||||
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
|
||||
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
|
||||
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
|
||||
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
|
||||
for generating the small 24 x 24 latents given a text prompt.
|
||||
|
||||
The Stage C model operates on the small 24 x 24 latents and denoises the latents conditioned on text prompts. The model is also the largest component in the Cascade pipeline and is meant to be used with the `StableCascadePriorPipeline`
|
||||
|
||||
The Stage B and Stage A models are used with the `StableCascadeDecoderPipeline` and are responsible for generating the final image given the small 24 x 24 latents.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
There are some restrictions on data types that can be used with the Stable Cascade models. The official checkpoints for the `StableCascadePriorPipeline` do not support the `torch.float16` data type. Please use `torch.bfloat16` instead.
|
||||
|
||||
In order to use the `torch.bfloat16` data type with the `StableCascadeDecoderPipeline` you need to have PyTorch 2.2.0 or higher installed. This also means that using the `StableCascadeCombinedPipeline` with `torch.bfloat16` requires PyTorch 2.2.0 or higher, since it calls the `StableCascadeDecoderPipeline` internally.
|
||||
|
||||
If it is not possible to install PyTorch 2.2.0 or higher in your environment, the `StableCascadeDecoderPipeline` can be used on its own with the `torch.float16` data type. You can download the full precision or `bf16` variant weights for the pipeline and cast the weights to `torch.float16`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Usage example
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
|
||||
|
||||
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
negative_prompt = ""
|
||||
|
||||
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
|
||||
|
||||
prior.enable_model_cpu_offload()
|
||||
prior_output = prior(
|
||||
prompt=prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=4.0,
|
||||
num_images_per_prompt=1,
|
||||
num_inference_steps=20
|
||||
)
|
||||
|
||||
decoder.enable_model_cpu_offload()
|
||||
decoder_output = decoder(
|
||||
image_embeddings=prior_output.image_embeddings.to(torch.float16),
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=0.0,
|
||||
output_type="pil",
|
||||
num_inference_steps=10
|
||||
).images[0]
|
||||
decoder_output.save("cascade.png")
|
||||
```
|
||||
|
||||
## Using the Lite Versions of the Stage B and Stage C models
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import (
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
StableCascadeUNet,
|
||||
)
|
||||
|
||||
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
negative_prompt = ""
|
||||
|
||||
prior_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade-prior", subfolder="prior_lite")
|
||||
decoder_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade", subfolder="decoder_lite")
|
||||
|
||||
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet)
|
||||
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet)
|
||||
|
||||
prior.enable_model_cpu_offload()
|
||||
prior_output = prior(
|
||||
prompt=prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=4.0,
|
||||
num_images_per_prompt=1,
|
||||
num_inference_steps=20
|
||||
)
|
||||
|
||||
decoder.enable_model_cpu_offload()
|
||||
decoder_output = decoder(
|
||||
image_embeddings=prior_output.image_embeddings,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=0.0,
|
||||
output_type="pil",
|
||||
num_inference_steps=10
|
||||
).images[0]
|
||||
decoder_output.save("cascade.png")
|
||||
```
|
||||
|
||||
## Loading original checkpoints with `from_single_file`
|
||||
|
||||
Loading the original format checkpoints is supported via `from_single_file` method in the StableCascadeUNet.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import (
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
StableCascadeUNet,
|
||||
)
|
||||
|
||||
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
negative_prompt = ""
|
||||
|
||||
prior_unet = StableCascadeUNet.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-cascade/resolve/main/stage_c_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
decoder_unet = StableCascadeUNet.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet, torch_dtype=torch.bfloat16)
|
||||
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet, torch_dtype=torch.bfloat16)
|
||||
|
||||
prior.enable_model_cpu_offload()
|
||||
prior_output = prior(
|
||||
prompt=prompt,
|
||||
height=1024,
|
||||
width=1024,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=4.0,
|
||||
num_images_per_prompt=1,
|
||||
num_inference_steps=20
|
||||
)
|
||||
|
||||
decoder.enable_model_cpu_offload()
|
||||
decoder_output = decoder(
|
||||
image_embeddings=prior_output.image_embeddings,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=0.0,
|
||||
output_type="pil",
|
||||
num_inference_steps=10
|
||||
).images[0]
|
||||
decoder_output.save("cascade-single-file.png")
|
||||
```
|
||||
|
||||
## Uses
|
||||
|
||||
### Direct Use
|
||||
@@ -194,7 +53,7 @@ Excluded uses are described below.
|
||||
|
||||
### Out-of-Scope Use
|
||||
|
||||
The model was not trained to be factual or true representations of people or events,
|
||||
The model was not trained to be factual or true representations of people or events,
|
||||
and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
||||
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
|
||||
|
||||
|
||||
@@ -1215,7 +1215,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
|
||||
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
@@ -1366,14 +1366,14 @@ def main(args):
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
|
||||
"Defaulting to adamW"
|
||||
)
|
||||
args.optimizer = "adamw"
|
||||
|
||||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
)
|
||||
@@ -1407,11 +1407,11 @@ def main(args):
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
if args.learning_rate <= 0.1:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
|
||||
@@ -1317,7 +1317,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
|
||||
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
@@ -1522,14 +1522,14 @@ def main(args):
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
|
||||
"Defaulting to adamW"
|
||||
)
|
||||
args.optimizer = "adamw"
|
||||
|
||||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
)
|
||||
@@ -1563,11 +1563,11 @@ def main(args):
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
if args.learning_rate <= 0.1:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
|
||||
@@ -452,7 +452,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -308,7 +308,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1068,7 +1068,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -180,7 +180,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
|
||||
logger_name = "test" if is_final_validation else "validation"
|
||||
tracker.log({logger_name: formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -928,7 +928,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -325,7 +325,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1083,7 +1083,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -285,7 +285,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
|
||||
|
||||
tracker.log({f"validation/{name}": formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1023,7 +1023,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -303,7 +303,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
|
||||
|
||||
tracker.log({f"validation/{name}": formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1083,7 +1083,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -178,7 +178,7 @@ def log_validation(
|
||||
|
||||
tracker.log({tracker_key: formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -861,7 +861,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -128,7 +128,7 @@ def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args
|
||||
|
||||
wandb.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {args.report_to}")
|
||||
logger.warn(f"image logging not implemented for {args.report_to}")
|
||||
|
||||
return image_logs
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
|
||||
tracker.log({tracker_key: formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -929,7 +929,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -904,7 +904,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
attention_class = CustomDiffusionXFormersAttnProcessor
|
||||
|
||||
@@ -987,7 +987,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -895,7 +895,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -1141,7 +1141,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
|
||||
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
@@ -1317,14 +1317,14 @@ def main(args):
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
|
||||
"Defaulting to adamW"
|
||||
)
|
||||
args.optimizer = "adamw"
|
||||
|
||||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
)
|
||||
@@ -1358,11 +1358,11 @@ def main(args):
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
if args.learning_rate <= 0.1:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
|
||||
@@ -488,7 +488,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -580,7 +580,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -177,7 +177,7 @@ def log_validation(vae, image_encoder, image_processor, unet, args, accelerator,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
@@ -534,7 +534,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -180,7 +180,7 @@ def log_validation(
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -219,7 +219,7 @@ def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name=
|
||||
if args.num_classes is not None:
|
||||
class_labels = list(range(args.num_classes))
|
||||
else:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"The model is class-conditional but the number of classes is not set. The generated images will be"
|
||||
" unconditional rather than class-conditional."
|
||||
)
|
||||
@@ -266,7 +266,7 @@ def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name=
|
||||
|
||||
tracker.log({f"validation/{name}": formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -863,14 +863,14 @@ def main(args):
|
||||
elif args.model_config_name_or_path is None:
|
||||
# TODO: use default architectures from iCT paper
|
||||
if not args.class_conditional and (args.num_classes is not None or args.class_embed_type is not None):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"`--class_conditional` is set to `False` but `--num_classes` is set to {args.num_classes} and"
|
||||
f" `--class_embed_type` is set to {args.class_embed_type}. These values will be overridden to `None`."
|
||||
)
|
||||
args.num_classes = None
|
||||
args.class_embed_type = None
|
||||
elif args.class_conditional and args.num_classes is None and args.class_embed_type is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"`--class_conditional` is set to `True` but neither `--num_classes` nor `--class_embed_type` is set."
|
||||
"`class_conditional` will be overridden to `False`."
|
||||
)
|
||||
@@ -996,7 +996,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -407,7 +407,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -1057,7 +1057,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -574,7 +574,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -672,7 +672,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -516,7 +516,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -608,7 +608,7 @@ def main():
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and args.only_save_embeds:
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = not args.only_save_embeds
|
||||
|
||||
@@ -541,7 +541,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -645,7 +645,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
@@ -901,7 +901,7 @@ def main():
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and args.only_save_embeds:
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = not args.only_save_embeds
|
||||
|
||||
@@ -108,7 +108,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
@@ -523,7 +523,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -687,7 +687,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
@@ -916,7 +916,7 @@ def main():
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and not args.save_as_full_pipeline:
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = args.save_as_full_pipeline
|
||||
|
||||
+1
-1
@@ -410,7 +410,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -629,7 +629,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -167,7 +167,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
|
||||
|
||||
tracker.log({"validation": formatted_images})
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
gc.collect()
|
||||
@@ -932,7 +932,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -183,7 +183,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
@@ -608,7 +608,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -497,7 +497,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -616,7 +616,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -712,7 +712,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -708,7 +708,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
@@ -966,7 +966,7 @@ def main():
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and not args.save_as_full_pipeline:
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = args.save_as_full_pipeline
|
||||
|
||||
@@ -711,7 +711,7 @@ def main():
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
@@ -1022,7 +1022,7 @@ def main():
|
||||
)
|
||||
|
||||
if args.push_to_hub and not args.save_as_full_pipeline:
|
||||
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = args.save_as_full_pipeline
|
||||
|
||||
@@ -408,7 +408,7 @@ def main(args):
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@@ -184,7 +184,7 @@ def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dty
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -182,7 +182,7 @@ def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dty
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
+156
-159
@@ -1,7 +1,7 @@
|
||||
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import (
|
||||
@@ -18,56 +18,23 @@ from diffusers import (
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
|
||||
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
|
||||
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
|
||||
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
|
||||
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
|
||||
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
|
||||
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
|
||||
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
|
||||
parser.add_argument(
|
||||
"--prior_output_path", default="stable-cascade-prior", type=str, help="Hub organization to save the pipelines to"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_output_path",
|
||||
type=str,
|
||||
default="stable-cascade-decoder",
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--combined_output_path",
|
||||
type=str,
|
||||
default="stable-cascade-combined",
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument("--save_combined", action="store_true")
|
||||
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
|
||||
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.skip_stage_b and args.skip_stage_c:
|
||||
raise ValueError("At least one stage should be converted")
|
||||
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
|
||||
raise ValueError("Cannot skip stages when creating a combined pipeline")
|
||||
|
||||
model_path = args.model_path
|
||||
|
||||
device = "cpu"
|
||||
if args.variant == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
# set paths to model weights
|
||||
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
|
||||
@@ -85,134 +52,164 @@ tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b1
|
||||
feature_extractor = CLIPImageProcessor()
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
# Prior
|
||||
if args.use_safetensors:
|
||||
orig_state_dict = load_file(prior_checkpoint_path, device=device)
|
||||
else:
|
||||
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
|
||||
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
prior_model = StableCascadeUNet(
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=1,
|
||||
conditioning_dim=2048,
|
||||
block_out_channels=[2048, 2048],
|
||||
num_attention_heads=[32, 32],
|
||||
down_num_layers_per_block=[8, 24],
|
||||
up_num_layers_per_block=[24, 8],
|
||||
down_blocks_repeat_mappers=[1, 1],
|
||||
up_blocks_repeat_mappers=[1, 1],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_in_channels=1280,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels=768,
|
||||
clip_seq=4,
|
||||
kernel_size=3,
|
||||
dropout=[0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca", "crp"],
|
||||
switch_level=[False],
|
||||
)
|
||||
load_model_dict_into_meta(prior_model, state_dict)
|
||||
|
||||
# scheduler for prior and decoder
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
if not args.skip_stage_c:
|
||||
# Prior
|
||||
if args.use_safetensors:
|
||||
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
|
||||
# Prior pipeline
|
||||
prior_pipeline = StableCascadePriorPipeline(
|
||||
prior=prior_model,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)
|
||||
|
||||
# Decoder
|
||||
if args.use_safetensors:
|
||||
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
|
||||
else:
|
||||
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
|
||||
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
# rename clip_mapper to clip_txt_pooled_mapper
|
||||
elif key.endswith("clip_mapper.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
|
||||
elif key.endswith("clip_mapper.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
|
||||
else:
|
||||
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
|
||||
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
|
||||
|
||||
with ctx():
|
||||
prior_model = StableCascadeUNet(
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=1,
|
||||
conditioning_dim=2048,
|
||||
block_out_channels=[2048, 2048],
|
||||
num_attention_heads=[32, 32],
|
||||
down_num_layers_per_block=[8, 24],
|
||||
up_num_layers_per_block=[24, 8],
|
||||
down_blocks_repeat_mappers=[1, 1],
|
||||
up_blocks_repeat_mappers=[1, 1],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_in_channels=1280,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels=768,
|
||||
clip_seq=4,
|
||||
kernel_size=3,
|
||||
dropout=[0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca", "crp"],
|
||||
switch_level=[False],
|
||||
)
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(prior_model, prior_state_dict)
|
||||
else:
|
||||
prior_model.load_state_dict(prior_state_dict)
|
||||
|
||||
# Prior pipeline
|
||||
prior_pipeline = StableCascadePriorPipeline(
|
||||
prior=prior_model,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
prior_pipeline.to(dtype).save_pretrained(
|
||||
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
with accelerate.init_empty_weights():
|
||||
decoder = StableCascadeUNet(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=2,
|
||||
conditioning_dim=1280,
|
||||
block_out_channels=[320, 640, 1280, 1280],
|
||||
down_num_layers_per_block=[2, 6, 28, 6],
|
||||
up_num_layers_per_block=[6, 28, 6, 2],
|
||||
down_blocks_repeat_mappers=[1, 1, 1, 1],
|
||||
up_blocks_repeat_mappers=[3, 3, 2, 2],
|
||||
num_attention_heads=[0, 0, 20, 20],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_seq=4,
|
||||
effnet_in_channels=16,
|
||||
pixel_mapper_in_channels=3,
|
||||
kernel_size=3,
|
||||
dropout=[0, 0, 0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca"],
|
||||
)
|
||||
load_model_dict_into_meta(decoder, state_dict)
|
||||
|
||||
if not args.skip_stage_b:
|
||||
# VQGAN from Wuerstchen-V2
|
||||
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
|
||||
|
||||
# Decoder pipeline
|
||||
decoder_pipeline = StableCascadeDecoderPipeline(
|
||||
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)
|
||||
|
||||
# Stable Cascade combined pipeline
|
||||
stable_cascade_pipeline = StableCascadeCombinedPipeline(
|
||||
# Decoder
|
||||
if args.use_safetensors:
|
||||
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
|
||||
else:
|
||||
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
|
||||
|
||||
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
|
||||
with ctx():
|
||||
decoder = StableCascadeUNet(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=2,
|
||||
conditioning_dim=1280,
|
||||
block_out_channels=[320, 640, 1280, 1280],
|
||||
down_num_layers_per_block=[2, 6, 28, 6],
|
||||
up_num_layers_per_block=[6, 28, 6, 2],
|
||||
down_blocks_repeat_mappers=[1, 1, 1, 1],
|
||||
up_blocks_repeat_mappers=[3, 3, 2, 2],
|
||||
num_attention_heads=[0, 0, 20, 20],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_seq=4,
|
||||
effnet_in_channels=16,
|
||||
pixel_mapper_in_channels=3,
|
||||
kernel_size=3,
|
||||
dropout=[0, 0, 0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca"],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(decoder, decoder_state_dict)
|
||||
else:
|
||||
decoder.load_state_dict(decoder_state_dict)
|
||||
|
||||
# VQGAN from Wuerstchen-V2
|
||||
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
|
||||
|
||||
# Decoder pipeline
|
||||
decoder_pipeline = StableCascadeDecoderPipeline(
|
||||
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.to(dtype).save_pretrained(
|
||||
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
|
||||
if args.save_combined:
|
||||
# Stable Cascade combined pipeline
|
||||
stable_cascade_pipeline = StableCascadeCombinedPipeline(
|
||||
# Decoder
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqmodel,
|
||||
# Prior
|
||||
prior_text_encoder=text_encoder,
|
||||
prior_tokenizer=tokenizer,
|
||||
prior_prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
prior_image_encoder=image_encoder,
|
||||
prior_feature_extractor=feature_extractor,
|
||||
)
|
||||
stable_cascade_pipeline.to(dtype).save_pretrained(
|
||||
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqmodel,
|
||||
# Prior
|
||||
prior_text_encoder=text_encoder,
|
||||
prior_tokenizer=tokenizer,
|
||||
prior_prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
prior_image_encoder=image_encoder,
|
||||
prior_feature_extractor=feature_extractor,
|
||||
)
|
||||
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
CLIPConfig,
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
DDPMWuerstchenScheduler,
|
||||
StableCascadeCombinedPipeline,
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
|
||||
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
|
||||
parser.add_argument(
|
||||
"--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file"
|
||||
)
|
||||
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
|
||||
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
|
||||
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
|
||||
parser.add_argument(
|
||||
"--prior_output_path",
|
||||
default="stable-cascade-prior-lite",
|
||||
type=str,
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_output_path",
|
||||
type=str,
|
||||
default="stable-cascade-decoder-lite",
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--combined_output_path",
|
||||
type=str,
|
||||
default="stable-cascade-combined-lite",
|
||||
help="Hub organization to save the pipelines to",
|
||||
)
|
||||
parser.add_argument("--save_combined", action="store_true")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
|
||||
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.skip_stage_b and args.skip_stage_c:
|
||||
raise ValueError("At least one stage should be converted")
|
||||
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
|
||||
raise ValueError("Cannot skip stages when creating a combined pipeline")
|
||||
|
||||
model_path = args.model_path
|
||||
|
||||
device = "cpu"
|
||||
if args.variant == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
# set paths to model weights
|
||||
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
|
||||
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
|
||||
|
||||
# Clip Text encoder and tokenizer
|
||||
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
config.text_config.projection_dim = config.projection_dim
|
||||
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
|
||||
# image processor
|
||||
feature_extractor = CLIPImageProcessor()
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
||||
# scheduler for prior and decoder
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
if not args.skip_stage_c:
|
||||
# Prior
|
||||
if args.use_safetensors:
|
||||
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
|
||||
else:
|
||||
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
|
||||
|
||||
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
|
||||
with ctx():
|
||||
prior_model = StableCascadeUNet(
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=1,
|
||||
conditioning_dim=1536,
|
||||
block_out_channels=[1536, 1536],
|
||||
num_attention_heads=[24, 24],
|
||||
down_num_layers_per_block=[4, 12],
|
||||
up_num_layers_per_block=[12, 4],
|
||||
down_blocks_repeat_mappers=[1, 1],
|
||||
up_blocks_repeat_mappers=[1, 1],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_in_channels=1280,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels=768,
|
||||
clip_seq=4,
|
||||
kernel_size=3,
|
||||
dropout=[0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca", "crp"],
|
||||
switch_level=[False],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(prior_model, prior_state_dict)
|
||||
else:
|
||||
prior_model.load_state_dict(prior_state_dict)
|
||||
|
||||
# Prior pipeline
|
||||
prior_pipeline = StableCascadePriorPipeline(
|
||||
prior=prior_model,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
prior_pipeline.to(dtype).save_pretrained(
|
||||
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
|
||||
if not args.skip_stage_b:
|
||||
# Decoder
|
||||
if args.use_safetensors:
|
||||
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
|
||||
else:
|
||||
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
|
||||
|
||||
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
|
||||
|
||||
with ctx():
|
||||
decoder = StableCascadeUNet(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=2,
|
||||
conditioning_dim=1280,
|
||||
block_out_channels=[320, 576, 1152, 1152],
|
||||
down_num_layers_per_block=[2, 4, 14, 4],
|
||||
up_num_layers_per_block=[4, 14, 4, 2],
|
||||
down_blocks_repeat_mappers=[1, 1, 1, 1],
|
||||
up_blocks_repeat_mappers=[2, 2, 2, 2],
|
||||
num_attention_heads=[0, 9, 18, 18],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_seq=4,
|
||||
effnet_in_channels=16,
|
||||
pixel_mapper_in_channels=3,
|
||||
kernel_size=3,
|
||||
dropout=[0, 0, 0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca"],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(decoder, decoder_state_dict)
|
||||
else:
|
||||
decoder.load_state_dict(decoder_state_dict)
|
||||
|
||||
# VQGAN from Wuerstchen-V2
|
||||
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
|
||||
|
||||
# Decoder pipeline
|
||||
decoder_pipeline = StableCascadeDecoderPipeline(
|
||||
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.to(dtype).save_pretrained(
|
||||
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
|
||||
if args.save_combined:
|
||||
# Stable Cascade combined pipeline
|
||||
stable_cascade_pipeline = StableCascadeCombinedPipeline(
|
||||
# Decoder
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqmodel,
|
||||
# Prior
|
||||
prior_text_encoder=text_encoder,
|
||||
prior_tokenizer=tokenizer,
|
||||
prior_prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
prior_image_encoder=image_encoder,
|
||||
prior_feature_extractor=feature_extractor,
|
||||
)
|
||||
stable_cascade_pipeline.to(dtype).save_pretrained(
|
||||
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
|
||||
)
|
||||
@@ -430,7 +430,7 @@ class LoraLoaderMixin:
|
||||
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
||||
if not USE_PEFT_BACKEND:
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
logger.warning(warn_message)
|
||||
logger.warn(warn_message)
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
if adapter_name in getattr(unet, "peft_config", {}):
|
||||
@@ -882,7 +882,7 @@ class LoraLoaderMixin:
|
||||
if fuse_unet or fuse_text_encoder:
|
||||
self.num_fused_loras += 1
|
||||
if self.num_fused_loras > 1:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.",
|
||||
)
|
||||
|
||||
|
||||
@@ -56,8 +56,6 @@ def build_sub_model_components(
|
||||
|
||||
if component_name == "unet":
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
|
||||
unet_components = create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
@@ -66,7 +64,6 @@ def build_sub_model_components(
|
||||
image_size=image_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type=model_type,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
return unet_components
|
||||
|
||||
@@ -303,9 +300,7 @@ class FromSingleFileMixin:
|
||||
continue
|
||||
init_kwargs.update(components)
|
||||
|
||||
additional_components = set_additional_components(
|
||||
class_name, original_config, checkpoint=checkpoint, model_type=model_type
|
||||
)
|
||||
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
|
||||
if additional_components:
|
||||
init_kwargs.update(additional_components)
|
||||
|
||||
|
||||
@@ -81,87 +81,6 @@ SCHEDULER_DEFAULT_CONFIG = {
|
||||
"timestep_spacing": "leading",
|
||||
}
|
||||
|
||||
|
||||
STABLE_CASCADE_DEFAULT_CONFIGS = {
|
||||
"stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"},
|
||||
"stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"},
|
||||
"stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"},
|
||||
"stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"},
|
||||
}
|
||||
|
||||
|
||||
def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
|
||||
is_stage_c = "clip_txt_mapper.weight" in original_state_dict
|
||||
|
||||
if is_stage_c:
|
||||
state_dict = {}
|
||||
for key in original_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = original_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = original_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = original_state_dict[key]
|
||||
else:
|
||||
state_dict = {}
|
||||
for key in original_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = original_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = original_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
# rename clip_mapper to clip_txt_pooled_mapper
|
||||
elif key.endswith("clip_mapper.weight"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
|
||||
elif key.endswith("clip_mapper.bias"):
|
||||
weights = original_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = original_state_dict[key]
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def infer_stable_cascade_single_file_config(checkpoint):
|
||||
is_stage_c = "clip_txt_mapper.weight" in checkpoint
|
||||
is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint
|
||||
|
||||
if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
|
||||
config_type = "stage_c_lite"
|
||||
elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
|
||||
config_type = "stage_c"
|
||||
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
|
||||
config_type = "stage_b_lite"
|
||||
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
|
||||
config_type = "stage_b"
|
||||
|
||||
return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
|
||||
|
||||
|
||||
DIFFUSERS_TO_LDM_MAPPING = {
|
||||
"unet": {
|
||||
"layers": {
|
||||
@@ -310,34 +229,10 @@ def fetch_ldm_config_and_checkpoint(
|
||||
cache_dir=None,
|
||||
local_files_only=None,
|
||||
revision=None,
|
||||
):
|
||||
checkpoint = load_single_file_model_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
|
||||
|
||||
return original_config, checkpoint
|
||||
|
||||
|
||||
def load_single_file_model_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=False,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
token=None,
|
||||
cache_dir=None,
|
||||
local_files_only=None,
|
||||
revision=None,
|
||||
):
|
||||
if os.path.isfile(pretrained_model_link_or_path):
|
||||
checkpoint = load_state_dict(pretrained_model_link_or_path)
|
||||
|
||||
else:
|
||||
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
|
||||
checkpoint_path = _get_model_file(
|
||||
@@ -357,7 +252,9 @@ def load_single_file_model_checkpoint(
|
||||
while "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
return checkpoint
|
||||
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
|
||||
|
||||
return original_config, checkpoint
|
||||
|
||||
|
||||
def infer_original_config_file(class_name, checkpoint):
|
||||
@@ -410,7 +307,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
|
||||
return original_config
|
||||
|
||||
|
||||
def infer_model_type(original_config, checkpoint, model_type=None):
|
||||
def infer_model_type(original_config, checkpoint=None, model_type=None):
|
||||
if model_type is not None:
|
||||
return model_type
|
||||
|
||||
@@ -987,7 +884,7 @@ def create_diffusers_controlnet_model_from_ldm(
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
@@ -1163,7 +1060,7 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
@@ -1258,7 +1155,7 @@ def create_text_encoder_from_open_clip_checkpoint(
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
@@ -1279,7 +1176,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
original_config,
|
||||
checkpoint,
|
||||
num_in_channels=None,
|
||||
upcast_attention=None,
|
||||
upcast_attention=False,
|
||||
extract_ema=False,
|
||||
image_size=None,
|
||||
torch_dtype=None,
|
||||
@@ -1307,8 +1204,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
)
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["in_channels"] = num_in_channels
|
||||
if upcast_attention is not None:
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
|
||||
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
@@ -1325,7 +1221,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
@@ -1387,7 +1283,7 @@ def create_diffusers_vae_model_from_ldm(
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -42,11 +42,6 @@ from ..utils import (
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .single_file_utils import (
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
infer_stable_cascade_single_file_config,
|
||||
load_single_file_model_checkpoint,
|
||||
)
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
@@ -350,7 +345,7 @@ class UNet2DConditionLoadersMixin:
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
|
||||
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
|
||||
if not USE_PEFT_BACKEND:
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
@@ -389,7 +384,7 @@ class UNet2DConditionLoadersMixin:
|
||||
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
|
||||
if is_text_encoder_present:
|
||||
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
|
||||
logger.warning(warn_message)
|
||||
logger.warn(warn_message)
|
||||
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
|
||||
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
||||
|
||||
@@ -901,103 +896,3 @@ class UNet2DConditionLoadersMixin:
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
|
||||
class FromOriginalUNetMixin:
|
||||
"""
|
||||
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
config: (`dict`, *optional*):
|
||||
Dictionary containing the configuration of the model:
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables of the model.
|
||||
|
||||
"""
|
||||
class_name = cls.__name__
|
||||
if class_name != "StableCascadeUNet":
|
||||
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
checkpoint = load_single_file_model_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = infer_stable_cascade_single_file_config(checkpoint)
|
||||
model_config = cls.load_config(**config, **kwargs)
|
||||
else:
|
||||
model_config = config
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(model_config, **kwargs)
|
||||
|
||||
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
return model
|
||||
|
||||
@@ -17,7 +17,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleLinear
|
||||
|
||||
|
||||
ACTIVATION_FUNCTIONS = {
|
||||
@@ -86,7 +87,9 @@ class GEGLU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
|
||||
self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
if gate.device.type != "mps":
|
||||
@@ -94,12 +97,9 @@ class GEGLU(nn.Module):
|
||||
# mps: gelu is not implemented for float16
|
||||
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
||||
|
||||
def forward(self, hidden_states, *args, **kwargs):
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
||||
def forward(self, hidden_states, scale: float = 1.0):
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
||||
return hidden_states * self.gelu(gate)
|
||||
|
||||
|
||||
|
||||
@@ -17,18 +17,18 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU
|
||||
from .attention_processor import Attention
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .lora import LoRACompatibleLinear
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
||||
def _chunked_feed_forward(
|
||||
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
|
||||
):
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
||||
raise ValueError(
|
||||
@@ -36,10 +36,18 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
|
||||
)
|
||||
|
||||
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
)
|
||||
if lora_scale is None:
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
)
|
||||
else:
|
||||
# TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
)
|
||||
|
||||
return ff_output
|
||||
|
||||
|
||||
@@ -291,10 +299,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
@@ -322,7 +326,10 @@ class BasicTransformerBlock(nn.Module):
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
# 1. Prepare GLIGEN inputs
|
||||
# 1. Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 2. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
@@ -341,7 +348,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 1.2 GLIGEN Control
|
||||
# 2.5 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
|
||||
@@ -387,9 +394,11 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
ff_output = _chunked_feed_forward(
|
||||
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
|
||||
)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
@@ -634,7 +643,7 @@ class FeedForward(nn.Module):
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
linear_cls = nn.Linear
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim, bias=bias)
|
||||
@@ -656,10 +665,11 @@ class FeedForward(nn.Module):
|
||||
if final_dropout:
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
if isinstance(module, compatible_cls):
|
||||
hidden_states = module(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -20,10 +20,10 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..image_processor import IPAdapterMaskProcessor
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils import USE_PEFT_BACKEND, deprecate, logging
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .lora import LoRALinearLayer
|
||||
from .lora import LoRACompatibleLinear, LoRALinearLayer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -181,7 +181,10 @@ class Attention(nn.Module):
|
||||
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
||||
)
|
||||
|
||||
linear_cls = nn.Linear
|
||||
if USE_PEFT_BACKEND:
|
||||
linear_cls = nn.Linear
|
||||
else:
|
||||
linear_cls = LoRACompatibleLinear
|
||||
|
||||
self.linear_cls = linear_cls
|
||||
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
||||
@@ -738,15 +741,12 @@ class AttnProcessor:
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -764,15 +764,15 @@ class AttnProcessor:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
@@ -783,7 +783,7 @@ class AttnProcessor:
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -914,15 +914,12 @@ class AttnAddedKVProcessor:
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
@@ -935,17 +932,17 @@ class AttnAddedKVProcessor:
|
||||
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
|
||||
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
||||
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
||||
|
||||
if not attn.only_cross_attention:
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
key = attn.to_k(hidden_states, *args)
|
||||
value = attn.to_v(hidden_states, *args)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
@@ -959,7 +956,7 @@ class AttnAddedKVProcessor:
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -987,15 +984,12 @@ class AttnAddedKVProcessor2_0:
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
@@ -1008,7 +1002,7 @@ class AttnAddedKVProcessor2_0:
|
||||
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.head_to_batch_dim(query, out_dim=4)
|
||||
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
@@ -1017,8 +1011,8 @@ class AttnAddedKVProcessor2_0:
|
||||
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
|
||||
|
||||
if not attn.only_cross_attention:
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
key = attn.to_k(hidden_states, *args)
|
||||
value = attn.to_v(hidden_states, *args)
|
||||
key = attn.head_to_batch_dim(key, out_dim=4)
|
||||
value = attn.head_to_batch_dim(value, out_dim=4)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
@@ -1035,7 +1029,7 @@ class AttnAddedKVProcessor2_0:
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1138,15 +1132,12 @@ class XFormersAttnProcessor:
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -1174,15 +1165,15 @@ class XFormersAttnProcessor:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
@@ -1195,7 +1186,7 @@ class XFormersAttnProcessor:
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1226,13 +1217,8 @@ class AttnProcessor2_0:
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
@@ -1256,15 +1242,16 @@ class AttnProcessor2_0:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
@@ -1284,7 +1271,7 @@ class AttnProcessor2_0:
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1325,13 +1312,8 @@ class FusedAttnProcessor2_0:
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
@@ -1355,16 +1337,17 @@ class FusedAttnProcessor2_0:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
if encoder_hidden_states is None:
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
qkv = attn.to_qkv(hidden_states, *args)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
else:
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
kv = attn.to_kv(encoder_hidden_states)
|
||||
kv = attn.to_kv(encoder_hidden_states, *args)
|
||||
split_size = kv.shape[-1] // 2
|
||||
key, value = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
@@ -1385,7 +1368,7 @@ class FusedAttnProcessor2_0:
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1876,7 +1859,7 @@ class LoRAAttnProcessor(nn.Module):
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
self_cls_name = self.__class__.__name__
|
||||
deprecate(
|
||||
self_cls_name,
|
||||
@@ -1894,7 +1877,7 @@ class LoRAAttnProcessor(nn.Module):
|
||||
|
||||
attn._modules.pop("processor")
|
||||
attn.processor = AttnProcessor()
|
||||
return attn.processor(attn, hidden_states, **kwargs)
|
||||
return attn.processor(attn, hidden_states, *args, **kwargs)
|
||||
|
||||
|
||||
class LoRAAttnProcessor2_0(nn.Module):
|
||||
@@ -1937,7 +1920,7 @@ class LoRAAttnProcessor2_0(nn.Module):
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
self_cls_name = self.__class__.__name__
|
||||
deprecate(
|
||||
self_cls_name,
|
||||
@@ -1955,7 +1938,7 @@ class LoRAAttnProcessor2_0(nn.Module):
|
||||
|
||||
attn._modules.pop("processor")
|
||||
attn.processor = AttnProcessor2_0()
|
||||
return attn.processor(attn, hidden_states, **kwargs)
|
||||
return attn.processor(attn, hidden_states, *args, **kwargs)
|
||||
|
||||
|
||||
class LoRAXFormersAttnProcessor(nn.Module):
|
||||
@@ -2016,7 +1999,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
self_cls_name = self.__class__.__name__
|
||||
deprecate(
|
||||
self_cls_name,
|
||||
@@ -2034,7 +2017,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
|
||||
attn._modules.pop("processor")
|
||||
attn.processor = XFormersAttnProcessor()
|
||||
return attn.processor(attn, hidden_states, **kwargs)
|
||||
return attn.processor(attn, hidden_states, *args, **kwargs)
|
||||
|
||||
|
||||
class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
@@ -2075,7 +2058,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
self_cls_name = self.__class__.__name__
|
||||
deprecate(
|
||||
self_cls_name,
|
||||
@@ -2093,7 +2076,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
|
||||
attn._modules.pop("processor")
|
||||
attn.processor = AttnAddedKVProcessor()
|
||||
return attn.processor(attn, hidden_states, **kwargs)
|
||||
return attn.processor(attn, hidden_states, *args, **kwargs)
|
||||
|
||||
|
||||
class IPAdapterAttnProcessor(nn.Module):
|
||||
|
||||
@@ -18,7 +18,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
from .normalization import RMSNorm
|
||||
from .upsampling import upfirdn2d_native
|
||||
|
||||
@@ -102,7 +103,7 @@ class Downsample2D(nn.Module):
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if norm_type == "ln_norm":
|
||||
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
||||
@@ -130,10 +131,7 @@ class Downsample2D(nn.Module):
|
||||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.norm is not None:
|
||||
@@ -145,7 +143,13 @@ class Downsample2D(nn.Module):
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
if not USE_PEFT_BACKEND:
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -18,9 +18,10 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils import USE_PEFT_BACKEND, deprecate
|
||||
from .activations import get_activation
|
||||
from .attention_processor import Attention
|
||||
from .lora import LoRACompatibleLinear
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -199,7 +200,7 @@ class TimestepEmbedding(nn.Module):
|
||||
sample_proj_bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
|
||||
|
||||
|
||||
@@ -204,9 +204,6 @@ class LoRALinearLayer(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRALinearLayer", "1.0.0", deprecation_message)
|
||||
|
||||
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
||||
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
@@ -267,9 +264,6 @@ class LoRAConv2dLayer(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message)
|
||||
|
||||
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
||||
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
|
||||
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
|
||||
|
||||
@@ -677,7 +677,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
@@ -705,7 +705,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# the weights so we don't have to do this again.
|
||||
|
||||
if "'Attention' object has no attribute" in str(e):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
||||
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
||||
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .activations import get_activation
|
||||
from .attention_processor import SpatialNorm
|
||||
from .downsampling import ( # noqa
|
||||
@@ -30,6 +30,7 @@ from .downsampling import ( # noqa
|
||||
KDownsample2D,
|
||||
downsample_2d,
|
||||
)
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from .normalization import AdaGroupNorm
|
||||
from .upsampling import ( # noqa
|
||||
FirUpsample2D,
|
||||
@@ -101,7 +102,7 @@ class ResnetBlockCondNorm2D(nn.Module):
|
||||
self.output_scale_factor = output_scale_factor
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
|
||||
conv_cls = nn.Conv2d
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
@@ -148,11 +149,12 @@ class ResnetBlockCondNorm2D(nn.Module):
|
||||
bias=conv_shortcut_bias,
|
||||
)
|
||||
|
||||
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states, temb)
|
||||
@@ -164,24 +166,26 @@ class ResnetBlockCondNorm2D(nn.Module):
|
||||
if hidden_states.shape[0] >= 64:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
hidden_states = hidden_states.contiguous()
|
||||
input_tensor = self.upsample(input_tensor)
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
input_tensor = self.upsample(input_tensor, scale=scale)
|
||||
hidden_states = self.upsample(hidden_states, scale=scale)
|
||||
|
||||
elif self.downsample is not None:
|
||||
input_tensor = self.downsample(input_tensor)
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
input_tensor = self.downsample(input_tensor, scale=scale)
|
||||
hidden_states = self.downsample(hidden_states, scale=scale)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states, temb)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
input_tensor = (
|
||||
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
|
||||
)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
@@ -263,8 +267,8 @@ class ResnetBlock2D(nn.Module):
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.skip_time_act = skip_time_act
|
||||
|
||||
linear_cls = nn.Linear
|
||||
conv_cls = nn.Conv2d
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
@@ -322,11 +326,12 @@ class ResnetBlock2D(nn.Module):
|
||||
bias=conv_shortcut_bias,
|
||||
)
|
||||
|
||||
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
@@ -337,18 +342,38 @@ class ResnetBlock2D(nn.Module):
|
||||
if hidden_states.shape[0] >= 64:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
hidden_states = hidden_states.contiguous()
|
||||
input_tensor = self.upsample(input_tensor)
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
input_tensor = (
|
||||
self.upsample(input_tensor, scale=scale)
|
||||
if isinstance(self.upsample, Upsample2D)
|
||||
else self.upsample(input_tensor)
|
||||
)
|
||||
hidden_states = (
|
||||
self.upsample(hidden_states, scale=scale)
|
||||
if isinstance(self.upsample, Upsample2D)
|
||||
else self.upsample(hidden_states)
|
||||
)
|
||||
elif self.downsample is not None:
|
||||
input_tensor = self.downsample(input_tensor)
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
input_tensor = (
|
||||
self.downsample(input_tensor, scale=scale)
|
||||
if isinstance(self.downsample, Downsample2D)
|
||||
else self.downsample(input_tensor)
|
||||
)
|
||||
hidden_states = (
|
||||
self.downsample(hidden_states, scale=scale)
|
||||
if isinstance(self.downsample, Downsample2D)
|
||||
else self.downsample(hidden_states)
|
||||
)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
|
||||
|
||||
if self.time_emb_proj is not None:
|
||||
if not self.skip_time_act:
|
||||
temb = self.nonlinearity(temb)
|
||||
temb = self.time_emb_proj(temb)[:, :, None, None]
|
||||
temb = (
|
||||
self.time_emb_proj(temb, scale)[:, :, None, None]
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.time_emb_proj(temb)[:, :, None, None]
|
||||
)
|
||||
|
||||
if self.time_embedding_norm == "default":
|
||||
if temb is not None:
|
||||
@@ -368,10 +393,12 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
input_tensor = (
|
||||
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
|
||||
)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
|
||||
@@ -19,16 +19,14 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, deprecate, is_torch_version, logging
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
@@ -117,8 +115,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
conv_cls = nn.Conv2d
|
||||
linear_cls = nn.Linear
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
@@ -306,9 +304,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
@@ -332,6 +327,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
@@ -339,13 +337,21 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
@@ -408,9 +414,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
else:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...utils import deprecate, is_torch_version, logging
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
||||
@@ -69,7 +69,7 @@ def get_down_block(
|
||||
):
|
||||
# If attn head dim is not defined, we default it to the number of heads
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
||||
)
|
||||
attention_head_dim = num_attention_heads
|
||||
@@ -354,7 +354,7 @@ def get_up_block(
|
||||
) -> nn.Module:
|
||||
# If attn head dim is not defined, we default it to the number of heads
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
||||
)
|
||||
attention_head_dim = num_attention_heads
|
||||
@@ -673,7 +673,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
||||
)
|
||||
attention_head_dim = in_channels
|
||||
@@ -844,11 +844,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -885,7 +882,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -985,8 +982,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
||||
|
||||
if attention_mask is None:
|
||||
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
||||
@@ -999,7 +995,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
|
||||
mask = attention_mask
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
# attn
|
||||
hidden_states = attn(
|
||||
@@ -1010,7 +1006,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
)
|
||||
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1039,7 +1035,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
self.downsample_type = downsample_type
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -1115,22 +1111,23 @@ class AttnDownBlock2D(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
cross_attention_kwargs.update({"scale": lora_scale})
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
if self.downsample_type == "resnet":
|
||||
hidden_states = downsampler(hidden_states, temb=temb)
|
||||
hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale)
|
||||
else:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
@@ -1239,12 +1236,10 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
additional_residuals: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
output_states = ()
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
blocks = list(zip(self.resnets, self.attentions))
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
@@ -1275,7 +1270,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1293,7 +1288,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1353,12 +1348,8 @@ class DownBlock2D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
@@ -1379,13 +1370,13 @@ class DownBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, scale=scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1456,17 +1447,13 @@ class DownEncoderBlock2D(nn.Module):
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=None)
|
||||
hidden_states = resnet(hidden_states, temb=None, scale=scale)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1493,7 +1480,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -1558,18 +1545,15 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb=None)
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb=None, scale=scale)
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1595,7 +1579,7 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
self.resnets = nn.ModuleList([])
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -1660,22 +1644,18 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
skip_sample: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
hidden_states = self.resnet_down(hidden_states, temb)
|
||||
hidden_states = self.resnet_down(hidden_states, temb, scale=scale)
|
||||
for downsampler in self.downsamplers:
|
||||
skip_sample = downsampler(skip_sample)
|
||||
|
||||
@@ -1751,21 +1731,16 @@ class SkipDownBlock2D(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
skip_sample: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale)
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
hidden_states = self.resnet_down(hidden_states, temb)
|
||||
hidden_states = self.resnet_down(hidden_states, temb, scale)
|
||||
for downsampler in self.downsamplers:
|
||||
skip_sample = downsampler(skip_sample)
|
||||
|
||||
@@ -1841,12 +1816,8 @@ class ResnetDownsampleBlock2D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
@@ -1867,13 +1838,13 @@ class ResnetDownsampleBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, temb)
|
||||
hidden_states = downsampler(hidden_states, temb, scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1984,11 +1955,10 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
output_states = ()
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
||||
|
||||
if attention_mask is None:
|
||||
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
||||
@@ -2021,7 +1991,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -2034,7 +2004,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, temb)
|
||||
hidden_states = downsampler(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -2088,12 +2058,8 @@ class KDownBlock2D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
@@ -2114,7 +2080,7 @@ class KDownBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
@@ -2199,11 +2165,8 @@ class KCrossAttnDownBlock2D(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
output_states = ()
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -2233,7 +2196,7 @@ class KCrossAttnDownBlock2D(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2281,7 +2244,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
self.upsample_type = upsample_type
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -2353,28 +2316,24 @@ class AttnUpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
if self.upsample_type == "resnet":
|
||||
hidden_states = upsampler(hidden_states, temb=temb)
|
||||
hidden_states = upsampler(hidden_states, temb=temb, scale=scale)
|
||||
else:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
hidden_states = upsampler(hidden_states, scale=scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2481,10 +2440,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
@@ -2538,7 +2494,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2550,7 +2506,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2611,13 +2567,8 @@ class UpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
@@ -2661,11 +2612,11 @@ class UpBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2732,9 +2683,11 @@ class UpDecoderBlock2D(nn.Module):
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
) -> torch.FloatTensor:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=temb)
|
||||
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -2766,7 +2719,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -2830,14 +2783,17 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
) -> torch.FloatTensor:
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb=temb)
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
hidden_states = upsampler(hidden_states, scale=scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2885,7 +2841,7 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
)
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
@@ -2942,22 +2898,18 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
skip_sample=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
hidden_states = self.attentions[0](hidden_states)
|
||||
cross_attention_kwargs = {"scale": scale}
|
||||
hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
|
||||
|
||||
if skip_sample is not None:
|
||||
skip_sample = self.upsampler(skip_sample)
|
||||
@@ -2971,7 +2923,7 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
|
||||
skip_sample = skip_sample + skip_sample_states
|
||||
|
||||
hidden_states = self.resnet_up(hidden_states, temb)
|
||||
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
|
||||
|
||||
return hidden_states, skip_sample
|
||||
|
||||
@@ -3054,20 +3006,15 @@ class SkipUpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
skip_sample=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if skip_sample is not None:
|
||||
skip_sample = self.upsampler(skip_sample)
|
||||
@@ -3081,7 +3028,7 @@ class SkipUpBlock2D(nn.Module):
|
||||
|
||||
skip_sample = skip_sample + skip_sample_states
|
||||
|
||||
hidden_states = self.resnet_up(hidden_states, temb)
|
||||
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
|
||||
|
||||
return hidden_states, skip_sample
|
||||
|
||||
@@ -3161,13 +3108,8 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
@@ -3191,11 +3133,11 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, temb)
|
||||
hidden_states = upsampler(hidden_states, temb, scale=scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -3311,9 +3253,8 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
||||
if attention_mask is None:
|
||||
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
||||
mask = None if encoder_hidden_states is None else encoder_attention_mask
|
||||
@@ -3351,7 +3292,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
@@ -3362,7 +3303,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, temb)
|
||||
hidden_states = upsampler(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -3423,13 +3364,8 @@ class KUpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[-1]
|
||||
if res_hidden_states_tuple is not None:
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
|
||||
@@ -3452,7 +3388,7 @@ class KUpBlock2D(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -3562,6 +3498,7 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
if res_hidden_states_tuple is not None:
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -3590,7 +3527,7 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -3693,8 +3630,6 @@ class KAttentionBlock(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
# 1. Self-Attention
|
||||
if self.add_self_attention:
|
||||
|
||||
@@ -1226,7 +1226,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
**additional_residuals,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
@@ -1297,6 +1297,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
scale=lora_scale,
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...utils import deprecate, is_torch_version, logging
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import Attention
|
||||
from ..resnet import (
|
||||
@@ -35,9 +35,6 @@ from ..transformers.transformer_temporal import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_down_block(
|
||||
down_block_type: str,
|
||||
num_layers: int,
|
||||
@@ -1008,14 +1005,9 @@ class DownBlockMotion(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
num_frames: int = 1,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
output_states = ()
|
||||
|
||||
blocks = zip(self.resnets, self.motion_modules)
|
||||
@@ -1037,18 +1029,18 @@ class DownBlockMotion(nn.Module):
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
create_custom_forward(resnet), hidden_states, temb, scale
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, scale=scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1181,12 +1173,10 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
additional_residuals: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
output_states = ()
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
|
||||
for i, (resnet, attn, motion_module) in enumerate(blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -1216,7 +1206,7 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1238,7 +1228,7 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1365,10 +1355,7 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
num_frames: int = 1,
|
||||
) -> torch.FloatTensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
@@ -1423,7 +1410,7 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1439,7 +1426,7 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1520,14 +1507,9 @@ class UpBlockMotion(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size=None,
|
||||
scale: float = 1.0,
|
||||
num_frames: int = 1,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
@@ -1577,12 +1559,12 @@ class UpBlockMotion(nn.Module):
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1705,11 +1687,8 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
num_frames: int = 1,
|
||||
) -> torch.FloatTensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||
|
||||
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
|
||||
for attn, resnet, motion_module in blocks:
|
||||
@@ -1758,7 +1737,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)[0]
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
ff_output = self.ff(hidden_states)
|
||||
ff_output = self.ff(hidden_states, scale=1.0)
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
@@ -21,7 +21,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.unet import FromOriginalUNetMixin
|
||||
from ...utils import BaseOutput
|
||||
from ..attention_processor import Attention
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -135,7 +134,7 @@ class StableCascadeUNetOutput(BaseOutput):
|
||||
sample: torch.FloatTensor = None
|
||||
|
||||
|
||||
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
|
||||
class StableCascadeUNet(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
|
||||
@@ -18,7 +18,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
from .normalization import RMSNorm
|
||||
|
||||
|
||||
@@ -110,7 +111,7 @@ class Upsample2D(nn.Module):
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
self.interpolate = interpolate
|
||||
conv_cls = nn.Conv2d
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if norm_type == "ln_norm":
|
||||
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
||||
@@ -140,12 +141,11 @@ class Upsample2D(nn.Module):
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
output_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.norm is not None:
|
||||
@@ -180,9 +180,15 @@ class Upsample2D(nn.Module):
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
hidden_states = self.conv(hidden_states)
|
||||
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.Conv2d_0(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -156,7 +156,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
|
||||
self.dtype = dtype
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
@@ -416,13 +416,13 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
|
||||
@@ -460,13 +460,13 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
|
||||
@@ -175,7 +175,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
)
|
||||
|
||||
if unet.config.in_channels != 6:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
|
||||
)
|
||||
|
||||
@@ -209,13 +209,13 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
|
||||
@@ -500,13 +500,13 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
|
||||
@@ -177,7 +177,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
)
|
||||
|
||||
if unet.config.in_channels != 6:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
|
||||
)
|
||||
|
||||
@@ -211,13 +211,13 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
|
||||
@@ -133,7 +133,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
)
|
||||
|
||||
if unet.config.in_channels != 6:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
|
||||
)
|
||||
|
||||
@@ -167,13 +167,13 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
|
||||
@@ -1333,7 +1333,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
**additional_residuals,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
@@ -1589,7 +1589,7 @@ class DownBlockFlat(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
|
||||
@@ -1611,13 +1611,13 @@ class DownBlockFlat(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, scale=scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1728,6 +1728,8 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
blocks = list(zip(self.resnets, self.attentions))
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
@@ -1758,7 +1760,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -1776,7 +1778,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1840,13 +1842,8 @@ class UpBlockFlat(nn.Module):
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
@@ -1890,11 +1887,11 @@ class UpBlockFlat(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2002,10 +1999,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
@@ -2059,7 +2053,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -2071,7 +2065,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2164,7 +2158,7 @@ class UNetMidBlockFlat(nn.Module):
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
||||
)
|
||||
attention_head_dim = in_channels
|
||||
@@ -2336,11 +2330,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -2377,7 +2368,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2478,8 +2469,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0)
|
||||
|
||||
if attention_mask is None:
|
||||
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
||||
@@ -2492,7 +2482,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
|
||||
mask = attention_mask
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
# attn
|
||||
hidden_states = attn(
|
||||
@@ -2503,6 +2493,6 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
)
|
||||
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -481,7 +481,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
|
||||
if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse(
|
||||
"0.23.0.dev0"
|
||||
):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"Please note that the expected format of `mask_image` has recently been changed. "
|
||||
"Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. "
|
||||
"As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. "
|
||||
|
||||
@@ -372,7 +372,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
|
||||
if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse(
|
||||
"0.23.0.dev0"
|
||||
):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"Please note that the expected format of `mask_image` has recently been changed. "
|
||||
"Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. "
|
||||
"As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. "
|
||||
|
||||
@@ -256,9 +256,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
break
|
||||
|
||||
if save_method_name is None:
|
||||
logger.warning(
|
||||
f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved."
|
||||
)
|
||||
logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
|
||||
# make sure that unsaveable components are not tried to be loaded afterward
|
||||
self.register_to_config(**{pipeline_component_name: (None, None)})
|
||||
continue
|
||||
@@ -1204,7 +1202,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
try:
|
||||
info = model_info(pretrained_model_name, token=token, revision=revision)
|
||||
except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
|
||||
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
|
||||
logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
|
||||
local_files_only = True
|
||||
model_info_call_error = e # save error to reraise it if model is not cached locally
|
||||
|
||||
@@ -1355,7 +1353,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
len(safetensors_variant_filenames) > 0
|
||||
and safetensors_model_filenames != safetensors_variant_filenames
|
||||
):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
else:
|
||||
@@ -1368,7 +1366,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
||||
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
|
||||
|
||||
@@ -514,13 +514,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
|
||||
@@ -19,7 +19,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import is_torch_version, replace_example_docstring
|
||||
from ...utils import replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
|
||||
@@ -29,10 +29,11 @@ from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
|
||||
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableCascadeCombinedPipeline
|
||||
>>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> from diffusions import StableCascadeCombinedPipeline
|
||||
|
||||
>>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16).to(
|
||||
... "cuda"
|
||||
... )
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> images = pipe(prompt=prompt)
|
||||
```
|
||||
@@ -259,12 +260,6 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
dtype = self.decoder_pipe.decoder.dtype
|
||||
if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
|
||||
raise ValueError(
|
||||
"`StableCascadeCombinedPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype."
|
||||
)
|
||||
|
||||
prior_outputs = self.prior_pipe(
|
||||
prompt=prompt if prompt_embeds is None else None,
|
||||
images=images,
|
||||
|
||||
@@ -147,7 +147,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
||||
self.dtype = dtype
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
+1
-1
@@ -82,7 +82,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMi
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
@@ -2,6 +2,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from ...utils import USE_PEFT_BACKEND
|
||||
|
||||
|
||||
class WuerstchenLayerNorm(nn.LayerNorm):
|
||||
@@ -17,7 +19,7 @@ class WuerstchenLayerNorm(nn.LayerNorm):
|
||||
class TimestepBlock(nn.Module):
|
||||
def __init__(self, c, c_timestep):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
self.mapper = linear_cls(c_timestep, c * 2)
|
||||
|
||||
def forward(self, x, t):
|
||||
@@ -29,8 +31,8 @@ class ResBlock(nn.Module):
|
||||
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
conv_cls = nn.Conv2d
|
||||
linear_cls = nn.Linear
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
|
||||
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
@@ -64,7 +66,7 @@ class AttnBlock(nn.Module):
|
||||
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
linear_cls = nn.Linear
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
self.self_attn = self_attn
|
||||
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
@@ -28,8 +28,9 @@ from ...models.attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import is_torch_version
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version
|
||||
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
|
||||
|
||||
|
||||
@@ -40,8 +41,8 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
@register_to_config
|
||||
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
|
||||
super().__init__()
|
||||
conv_cls = nn.Conv2d
|
||||
linear_cls = nn.Linear
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
self.c_r = c_r
|
||||
self.projection = conv_cls(c_in, c, kernel_size=1)
|
||||
|
||||
@@ -319,13 +319,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
|
||||
if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=False`."
|
||||
)
|
||||
self.register_to_config(lower_order_final=True)
|
||||
|
||||
if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
" `last_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
|
||||
)
|
||||
self.register_to_config(lower_order_final=True)
|
||||
|
||||
@@ -116,7 +116,7 @@ def export_to_obj(mesh, output_obj_path: str = None):
|
||||
|
||||
|
||||
def export_to_video(
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
||||
) -> str:
|
||||
if is_opencv_available():
|
||||
import cv2
|
||||
|
||||
@@ -22,6 +22,7 @@ from torch import nn
|
||||
|
||||
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
|
||||
from diffusers.models.embeddings import get_timestep_embedding
|
||||
from diffusers.models.lora import LoRACompatibleLinear
|
||||
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from diffusers.models.transformers.transformer_2d import Transformer2DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -481,7 +482,7 @@ class Transformer2DModelTests(unittest.TestCase):
|
||||
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
|
||||
|
||||
dim = 32
|
||||
inner_dim = 128
|
||||
@@ -505,7 +506,7 @@ class Transformer2DModelTests(unittest.TestCase):
|
||||
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
|
||||
|
||||
dim = 32
|
||||
inner_dim = 128
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import StableCascadeUNet
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
class StableCascadeUNetModelSlowTests(unittest.TestCase):
|
||||
def tearDown(self) -> None:
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_cascade_unet_prior_single_file_components(self):
|
||||
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
|
||||
single_file_unet = StableCascadeUNet.from_single_file(single_file_url)
|
||||
|
||||
single_file_unet_config = single_file_unet.config
|
||||
del single_file_unet
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
unet = StableCascadeUNet.from_pretrained(
|
||||
"stabilityai/stable-cascade-prior", subfolder="prior", revision="refs/pr/2", variant="bf16"
|
||||
)
|
||||
unet_config = unet.config
|
||||
del unet
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
|
||||
for param_name, param_value in single_file_unet_config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
|
||||
assert unet_config[param_name] == param_value
|
||||
|
||||
def test_stable_cascade_unet_decoder_single_file_components(self):
|
||||
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors"
|
||||
single_file_unet = StableCascadeUNet.from_single_file(single_file_url)
|
||||
|
||||
single_file_unet_config = single_file_unet.config
|
||||
del single_file_unet
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
unet = StableCascadeUNet.from_pretrained(
|
||||
"stabilityai/stable-cascade", subfolder="decoder", revision="refs/pr/44", variant="bf16"
|
||||
)
|
||||
unet_config = unet.config
|
||||
del unet
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
|
||||
for param_name, param_value in single_file_unet_config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
|
||||
assert unet_config[param_name] == param_value
|
||||
|
||||
def test_stable_cascade_unet_config_loading(self):
|
||||
config = StableCascadeUNet.load_config(
|
||||
pretrained_model_name_or_path="diffusers/stable-cascade-configs", subfolder="prior"
|
||||
)
|
||||
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
|
||||
|
||||
single_file_unet = StableCascadeUNet.from_single_file(single_file_url, config=config)
|
||||
single_file_unet_config = single_file_unet.config
|
||||
del single_file_unet
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
|
||||
for param_name, param_value in config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
|
||||
assert single_file_unet_config[param_name] == param_value
|
||||
|
||||
@require_torch_gpu
|
||||
def test_stable_cascade_unet_single_file_prior_forward_pass(self):
|
||||
dtype = torch.bfloat16
|
||||
generator = torch.Generator("cpu")
|
||||
|
||||
model_inputs = {
|
||||
"sample": randn_tensor((1, 16, 24, 24), generator=generator.manual_seed(0)).to("cuda", dtype),
|
||||
"timestep_ratio": torch.tensor([1]).to("cuda", dtype),
|
||||
"clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
|
||||
"clip_text": randn_tensor((1, 77, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
|
||||
"clip_img": randn_tensor((1, 1, 768), generator=generator.manual_seed(0)).to("cuda", dtype),
|
||||
"pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype),
|
||||
}
|
||||
|
||||
unet = StableCascadeUNet.from_pretrained(
|
||||
"stabilityai/stable-cascade-prior",
|
||||
subfolder="prior",
|
||||
revision="refs/pr/2",
|
||||
variant="bf16",
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
unet.to("cuda")
|
||||
with torch.no_grad():
|
||||
prior_output = unet(**model_inputs).sample.float().cpu().numpy()
|
||||
|
||||
# Remove UNet from GPU memory before loading the single file UNet model
|
||||
del unet
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
|
||||
single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype)
|
||||
single_file_unet.to("cuda")
|
||||
with torch.no_grad():
|
||||
prior_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy()
|
||||
|
||||
# Remove UNet from GPU memory before loading the single file UNet model
|
||||
del single_file_unet
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(prior_output.flatten(), prior_single_file_output.flatten())
|
||||
assert max_diff < 8e-3
|
||||
|
||||
@require_torch_gpu
|
||||
def test_stable_cascade_unet_single_file_decoder_forward_pass(self):
|
||||
dtype = torch.float32
|
||||
generator = torch.Generator("cpu")
|
||||
|
||||
model_inputs = {
|
||||
"sample": randn_tensor((1, 4, 256, 256), generator=generator.manual_seed(0)).to("cuda", dtype),
|
||||
"timestep_ratio": torch.tensor([1]).to("cuda", dtype),
|
||||
"clip_text": randn_tensor((1, 77, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
|
||||
"clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
|
||||
"pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype),
|
||||
}
|
||||
|
||||
unet = StableCascadeUNet.from_pretrained(
|
||||
"stabilityai/stable-cascade",
|
||||
subfolder="decoder",
|
||||
revision="refs/pr/44",
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
unet.to("cuda")
|
||||
with torch.no_grad():
|
||||
prior_output = unet(**model_inputs).sample.float().cpu().numpy()
|
||||
|
||||
# Remove UNet from GPU memory before loading the single file UNet model
|
||||
del unet
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b.safetensors"
|
||||
single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype)
|
||||
single_file_unet.to("cuda")
|
||||
with torch.no_grad():
|
||||
prior_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy()
|
||||
|
||||
# Remove UNet from GPU memory before loading the single file UNet model
|
||||
del single_file_unet
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(prior_output.flatten(), prior_single_file_output.flatten())
|
||||
assert max_diff < 1e-4
|
||||
@@ -838,11 +838,9 @@ class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
|
||||
pipe.unet.config[param_name] = False
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
|
||||
Reference in New Issue
Block a user