Compare commits

..

4 Commits

Author SHA1 Message Date
DN6 4838b5b884 Merge branch 'main' into torch-regression-tests 2025-01-07 11:07:42 +05:30
DN6 240ba165fd Merge branch 'main' into torch-regression-tests 2025-01-07 11:04:10 +05:30
DN6 64ae481eb8 update 2025-01-07 08:59:19 +05:30
DN6 be75474ad9 update 2025-01-07 08:43:50 +05:30
323 changed files with 1751 additions and 4282 deletions
+1 -1
View File
@@ -272,7 +272,7 @@ jobs:
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_version_cuda \
tests/models/test_modeling_common.py \
tests/models/test_modelling_common.py \
tests/pipelines/test_pipelines_common.py \
tests/pipelines/test_pipeline_utils.py \
tests/pipelines/test_pipelines.py \
-1
View File
@@ -266,7 +266,6 @@ jobs:
# TODO (sayakpaul, DN6): revisit `--no-deps`
python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
python -m uv pip install -U tokenizers
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
+1 -1
View File
@@ -83,7 +83,7 @@ jobs:
python utils/print_env.py
- name: PyTorch CUDA checkpoint tests on Ubuntu
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
+1 -1
View File
@@ -193,7 +193,7 @@ jobs:
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_cuda \
tests/models/test_modeling_common.py \
tests/models/test_modelling_common.py \
tests/pipelines/test_pipelines_common.py \
tests/pipelines/test_pipeline_utils.py \
tests/pipelines/test_pipelines.py \
-27
View File
@@ -62,33 +62,6 @@ image = pipeline(prompt).images[0]
image.save("auraflow.png")
```
Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) are also supported:
```py
import torch
from diffusers import (
AuraFlowPipeline,
GGUFQuantizationConfig,
AuraFlowTransformer2DModel,
)
transformer = AuraFlowTransformer2DModel.from_single_file(
"https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
"fal/AuraFlow-v0.3",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
prompt = "a cute pony in a field of flowers"
image = pipeline(prompt).images[0]
image.save("auraflow.png")
```
## AuraFlowPipeline
[[autodoc]] AuraFlowPipeline
+1 -1
View File
@@ -367,7 +367,7 @@ transformer_8bit = FluxTransformer2DModel.from_pretrained(
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder_2=text_encoder_8bit,
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
@@ -16,7 +16,7 @@
[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/tencent/HunyuanVideo).*
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*
<Tip>
@@ -45,14 +45,14 @@ from diffusers.utils import export_to_video
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
"tencent/HunyuanVideo",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
torch_dtype=torch.float16,
)
pipeline = HunyuanVideoPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
"tencent/HunyuanVideo",
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
+2 -2
View File
@@ -59,10 +59,10 @@ Refer to the [Quantization](../../quantization/overview) overview to learn more
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModelForCausalLM
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
text_encoder_8bit = AutoModelForCausalLM.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_diffusers",
subfolder="text_encoder",
quantization_config=quant_config,
@@ -78,10 +78,10 @@ from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
"hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
"tencent/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
"tencent/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
)
# reduce memory requirements
@@ -67,17 +67,6 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
```bash
huggingface-cli login
```
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
> [!NOTE]
> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
### Pivotal Tuning
**Training with text encoder(s)**
@@ -65,17 +65,6 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
```bash
huggingface-cli login
```
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
> [!NOTE]
> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
+5 -5
View File
@@ -33,12 +33,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_comparison.ipynb) | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_unclip.ipynb) | [Ray Wang](https://wrong.wang) |
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_text_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) |
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) |
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
@@ -50,7 +50,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
+8 -12
View File
@@ -372,7 +372,7 @@ class AdaptiveMaskInpaintPipeline(
self.register_adaptive_mask_model()
self.register_adaptive_mask_settings()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -386,7 +386,7 @@ class AdaptiveMaskInpaintPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -416,14 +416,10 @@ class AdaptiveMaskInpaintPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -442,7 +438,7 @@ class AdaptiveMaskInpaintPipeline(
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet is not None and unet.config.in_channels != 9:
if unet.config.in_channels != 9:
logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
self.register_modules(
@@ -454,7 +450,7 @@ class AdaptiveMaskInpaintPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -89,7 +89,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -103,7 +103,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -132,14 +132,10 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -166,7 +162,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+1 -1
View File
@@ -35,7 +35,7 @@ class EDICTPipeline(DiffusionPipeline):
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
+1 -1
View File
@@ -1342,7 +1342,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+1 -1
View File
@@ -221,7 +221,7 @@ class GlueGenStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin, St
language_adapter=language_adapter,
tensor_norm=tensor_norm,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+1 -1
View File
@@ -95,7 +95,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+7 -11
View File
@@ -109,7 +109,7 @@ class InstaFlowPipeline(
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -123,7 +123,7 @@ class InstaFlowPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -152,14 +152,10 @@ class InstaFlowPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -186,7 +182,7 @@ class InstaFlowPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -86,7 +86,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+7 -11
View File
@@ -191,7 +191,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -205,7 +205,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -234,14 +234,10 @@ class IPAdapterFaceIDStableDiffusionPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -269,7 +265,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+1 -1
View File
@@ -463,6 +463,6 @@ class StableDiffusionHighResFixPipeline(StableDiffusionPipeline):
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -69,7 +69,7 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
@@ -273,7 +273,7 @@ class LatentConsistencyModelWalkPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -67,7 +67,7 @@ class LatentConsistencyModelPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
+7 -11
View File
@@ -336,7 +336,7 @@ class LLMGroundedDiffusionPipeline(
# This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -350,7 +350,7 @@ class LLMGroundedDiffusionPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -379,14 +379,10 @@ class LLMGroundedDiffusionPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -414,7 +410,7 @@ class LLMGroundedDiffusionPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+7 -11
View File
@@ -496,7 +496,7 @@ class StableDiffusionLongPromptWeightingPipeline(
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -510,7 +510,7 @@ class StableDiffusionLongPromptWeightingPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -539,14 +539,10 @@ class StableDiffusionLongPromptWeightingPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -572,7 +568,7 @@ class StableDiffusionLongPromptWeightingPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
+4 -11
View File
@@ -673,16 +673,12 @@ class SDXLLongPromptWeightingPipeline(
image_encoder=image_encoder,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
self.default_sample_size = self.unet.config.sample_size
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -831,9 +827,7 @@ class SDXLLongPromptWeightingPipeline(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -885,8 +879,7 @@ class SDXLLongPromptWeightingPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
+6 -10
View File
@@ -3766,7 +3766,7 @@ class MatryoshkaPipeline(
else:
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -3780,7 +3780,7 @@ class MatryoshkaPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
# if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
# deprecation_message = (
# f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
# " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -3793,14 +3793,10 @@ class MatryoshkaPipeline(
# new_config["clip_sample"] = False
# scheduler._internal_dict = FrozenDict(new_config)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -98,7 +98,7 @@ class MultilingualStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -188,7 +188,7 @@ class AnimateDiffControlNetPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -308,7 +308,7 @@ class AnimateDiffImgToVideoPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -162,7 +162,7 @@ class AnimateDiffPipelineIpex(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+4 -11
View File
@@ -166,13 +166,9 @@ class DemoFusionSDXLPipeline(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
self.default_sample_size = self.unet.config.sample_size
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -294,9 +290,7 @@ class DemoFusionSDXLPipeline(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -348,8 +342,7 @@ class DemoFusionSDXLPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
+5 -9
View File
@@ -150,14 +150,10 @@ class FabricPipeline(DiffusionPipeline):
):
super().__init__()
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -183,7 +179,7 @@ class FabricPipeline(DiffusionPipeline):
tokenizer=tokenizer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -221,12 +221,13 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
vae_latent_channels=latent_channels,
vae_latent_channels=self.vae.config.latent_channels,
do_normalize=False,
do_binarize=False,
do_convert_grayscale=True,
@@ -875,10 +876,10 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
@@ -219,7 +219,9 @@ class RFInversionFluxPipeline(
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -417,7 +419,7 @@ class RFInversionFluxPipeline(
)
image = image.to(dtype)
x0 = self.vae.encode(image.to(self._execution_device)).latent_dist.sample()
x0 = self.vae.encode(image.to(self.device)).latent_dist.sample()
x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor
x0 = x0.to(dtype)
return x0, resized
@@ -820,10 +822,10 @@ class RFInversionFluxPipeline(
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
@@ -990,10 +992,10 @@ class RFInversionFluxPipeline(
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inversion_steps = retrieve_timesteps(
self.scheduler,
+7 -6
View File
@@ -64,7 +64,6 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
@@ -190,7 +189,9 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -756,10 +757,10 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
@@ -327,7 +327,9 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
@@ -209,18 +209,16 @@ class KolorsDifferentialImg2ImgPipeline(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
)
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
self.default_sample_size = self.unet.config.sample_size
# Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
def encode_prompt(
+7 -11
View File
@@ -131,7 +131,7 @@ class Prompt2PromptPipeline(
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -145,7 +145,7 @@ class Prompt2PromptPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -174,14 +174,10 @@ class Prompt2PromptPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -209,7 +205,7 @@ class Prompt2PromptPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -488,17 +488,13 @@ class StyleAlignedSDXLPipeline(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
self.default_sample_size = self.unet.config.sample_size
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -632,9 +628,7 @@ class StyleAlignedSDXLPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -694,8 +688,7 @@ class StyleAlignedSDXLPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -207,7 +207,7 @@ class StableDiffusion3DifferentialImg2ImgPipeline(DiffusionPipeline):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
)
@@ -417,7 +417,7 @@ class StableDiffusionBoxDiffPipeline(
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -431,7 +431,7 @@ class StableDiffusionBoxDiffPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -460,14 +460,10 @@ class StableDiffusionBoxDiffPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -495,7 +491,7 @@ class StableDiffusionBoxDiffPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -384,7 +384,7 @@ class StableDiffusionPAGPipeline(
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -398,7 +398,7 @@ class StableDiffusionPAGPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -427,14 +427,10 @@ class StableDiffusionPAGPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -462,7 +458,7 @@ class StableDiffusionPAGPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -151,7 +151,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
watermarker=watermarker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample="bilinear")
# self.register_to_config(requires_safety_checker=requires_safety_checker)
self.register_to_config(max_noise_level=max_noise_level)
@@ -226,16 +226,12 @@ class StableDiffusionXLControlNetAdapterPipeline(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
self.default_sample_size = self.unet.config.sample_size
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
@@ -363,9 +359,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -425,8 +419,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -374,16 +374,12 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
self.default_sample_size = self.unet.config.sample_size
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
@@ -511,9 +507,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -573,8 +567,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -258,7 +258,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -394,9 +394,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -456,8 +454,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -253,14 +253,10 @@ class StableDiffusionXLPipelineIpex(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
self.default_sample_size = self.unet.config.sample_size
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -394,9 +390,7 @@ class StableDiffusionXLPipelineIpex(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
pooled_prompt_embeds = prompt_embeds[0]
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -456,8 +450,7 @@ class StableDiffusionXLPipelineIpex(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
+7 -11
View File
@@ -108,7 +108,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -122,7 +122,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -151,14 +151,10 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -185,7 +181,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
feature_extractor=feature_extractor,
cc_projection=cc_projection,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# self.model_mode = None
+4 -7
View File
@@ -352,7 +352,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -632,7 +632,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.
control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame.
control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.
strength ('float'): SDEdit strength.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -789,7 +789,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
# Currently we only support single control
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
image=control_frames(frames[0]) if callable(control_frames) else control_frames[0],
image=control_frames[0],
width=width,
height=height,
batch_size=batch_size,
@@ -908,9 +908,6 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
@@ -927,7 +924,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
for idx in range(1, len(frames)):
image = frames[idx]
prev_image = frames[idx - 1]
control_image = control_frames(image) if callable(control_frames) else control_frames[idx]
control_image = control_frames[idx]
# 5.1 prepare frames
image = self.image_processor.preprocess(image).to(dtype=self.dtype)
prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
@@ -179,7 +179,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, StableDiffusio
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
@@ -278,7 +278,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, StableDiffusio
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
@@ -263,7 +263,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline, StableD
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
+7 -11
View File
@@ -105,7 +105,7 @@ class StableDiffusionIPEXPipeline(
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -119,7 +119,7 @@ class StableDiffusionIPEXPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -148,14 +148,10 @@ class StableDiffusionIPEXPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -182,7 +178,7 @@ class StableDiffusionIPEXPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1):
+1 -1
View File
@@ -66,7 +66,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline, StableDiffusionMixin):
requires_safety_checker: bool = True,
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -132,7 +132,7 @@ class StableDiffusionReferencePipeline(
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -146,7 +146,7 @@ class StableDiffusionReferencePipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -181,14 +181,10 @@ class StableDiffusionReferencePipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -206,7 +202,7 @@ class StableDiffusionReferencePipeline(
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet is not None and unet.config.in_channels != 4:
if unet.config.in_channels != 4:
logger.warning(
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
@@ -223,7 +219,7 @@ class StableDiffusionReferencePipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+8 -12
View File
@@ -187,7 +187,7 @@ class StableDiffusionRepaintPipeline(
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -201,7 +201,7 @@ class StableDiffusionRepaintPipeline(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -236,14 +236,10 @@ class StableDiffusionRepaintPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -261,7 +257,7 @@ class StableDiffusionRepaintPipeline(
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet is not None and unet.config.in_channels != 4:
if unet.config.in_channels != 4:
logger.warning(
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
@@ -278,7 +274,7 @@ class StableDiffusionRepaintPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -710,7 +710,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -724,7 +724,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -753,14 +753,10 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -810,7 +806,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -714,7 +714,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -728,7 +728,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -757,14 +757,10 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -814,7 +810,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -626,7 +626,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -640,7 +640,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -669,14 +669,10 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -726,7 +722,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
+2 -2
View File
@@ -71,7 +71,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -85,7 +85,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -120,7 +120,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
):
super().__init__()
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+3 -3
View File
@@ -121,7 +121,7 @@ prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
control_image=image,
condition_image=image,
num_inference_steps=50,
joint_attention_kwargs={"scale": 0.9},
guidance_scale=25.,
@@ -190,7 +190,7 @@ prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
control_image=image,
condition_image=image,
num_inference_steps=50,
guidance_scale=25.,
).images[0]
@@ -200,5 +200,5 @@ gen_images.save("output.png")
## Things to note
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
+9 -51
View File
@@ -122,6 +122,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
for _ in range(args.num_validation_images):
with autocast_ctx:
# need to fix in pipeline_flux_controlnet
image = pipeline(
prompt=validation_prompt,
control_image=validation_image,
@@ -158,7 +159,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
@@ -187,7 +188,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
# flux-control-{repo_id}
# control-lora-{repo_id}
These are Control weights trained on {base_model} with new type of conditioning.
{img_str}
@@ -433,7 +434,7 @@ def parse_args(input_args=None):
"--conditioning_image_column",
type=str,
default="conditioning_image",
help="The column of the dataset containing the control conditioning image.",
help="The column of the dataset containing the controlnet conditioning image.",
)
parser.add_argument(
"--caption_column",
@@ -441,7 +442,6 @@ def parse_args(input_args=None):
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
parser.add_argument(
"--max_train_samples",
type=int,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
default=None,
nargs="+",
help=(
"A set of paths to the control conditioning image be evaluated every `--validation_steps`"
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
" `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,11 +505,7 @@ def parse_args(input_args=None):
default=None,
help="Path to the jsonl file containing the training data.",
)
parser.add_argument(
"--only_target_transformer_blocks",
action="store_true",
help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).",
)
parser.add_argument(
"--guidance_scale",
type=float,
@@ -585,7 +581,7 @@ def parse_args(input_args=None):
if args.resolution % 8 != 0:
raise ValueError(
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
)
return args
@@ -669,12 +665,7 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_images = [image_transforms(image) for image in conditioning_images]
examples["pixel_values"] = images
examples["conditioning_pixel_values"] = conditioning_images
is_caption_list = isinstance(examples[args.caption_column][0], list)
if is_caption_list:
examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
else:
examples["captions"] = list(examples[args.caption_column])
examples["captions"] = list(examples[args.caption_column])
return examples
@@ -774,8 +765,7 @@ def main(args):
subfolder="scheduler",
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
if not args.only_target_transformer_blocks:
flux_transformer.requires_grad_(True)
flux_transformer.requires_grad_(True)
vae.requires_grad_(False)
# cast down and move to the CPU
@@ -807,12 +797,6 @@ def main(args):
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
if args.only_target_transformer_blocks:
flux_transformer.x_embedder.requires_grad_(True)
for name, module in flux_transformer.named_modules():
if "transformer_blocks" in name:
module.requires_grad_(True)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
@@ -990,32 +974,6 @@ def main(args):
else:
initial_global_step = 0
if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
logger.info("Logging some dataset samples.")
formatted_images = []
formatted_control_images = []
all_prompts = []
for i, batch in enumerate(train_dataloader):
images = (batch["pixel_values"] + 1) / 2
control_images = (batch["conditioning_pixel_values"] + 1) / 2
prompts = batch["captions"]
if len(formatted_images) > 10:
break
for img, control_img, prompt in zip(images, control_images, prompts):
formatted_images.append(img)
formatted_control_images.append(control_img)
all_prompts.append(prompt)
logged_artifacts = []
for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
logged_artifacts.append(wandb.Image(img, caption=prompt))
wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
wandb_tracker[0].log({"dataset_samples": logged_artifacts})
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
@@ -132,6 +132,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
for _ in range(args.num_validation_images):
with autocast_ctx:
# need to fix in pipeline_flux_controlnet
image = pipeline(
prompt=validation_prompt,
control_image=validation_image,
@@ -168,7 +169,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
@@ -197,7 +198,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
# control-lora-{repo_id}
# controlnet-lora-{repo_id}
These are Control LoRA weights trained on {base_model} with new type of conditioning.
{img_str}
@@ -255,7 +256,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
default="control-lora",
default="controlnet-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
@@ -465,7 +466,7 @@ def parse_args(input_args=None):
"--conditioning_image_column",
type=str,
default="conditioning_image",
help="The column of the dataset containing the control conditioning image.",
help="The column of the dataset containing the controlnet conditioning image.",
)
parser.add_argument(
"--caption_column",
@@ -473,7 +474,6 @@ def parse_args(input_args=None):
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
parser.add_argument(
"--max_train_samples",
type=int,
@@ -500,7 +500,7 @@ def parse_args(input_args=None):
default=None,
nargs="+",
help=(
"A set of paths to the control conditioning image be evaluated every `--validation_steps`"
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
" `--validation_image` that will be used with all `--validation_prompt`s."
@@ -613,7 +613,7 @@ def parse_args(input_args=None):
if args.resolution % 8 != 0:
raise ValueError(
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
)
return args
@@ -697,12 +697,7 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_images = [image_transforms(image) for image in conditioning_images]
examples["pixel_values"] = images
examples["conditioning_pixel_values"] = conditioning_images
is_caption_list = isinstance(examples[args.caption_column][0], list)
if is_caption_list:
examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
else:
examples["captions"] = list(examples[args.caption_column])
examples["captions"] = list(examples[args.caption_column])
return examples
@@ -1137,32 +1132,6 @@ def main(args):
else:
initial_global_step = 0
if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
logger.info("Logging some dataset samples.")
formatted_images = []
formatted_control_images = []
all_prompts = []
for i, batch in enumerate(train_dataloader):
images = (batch["pixel_values"] + 1) / 2
control_images = (batch["conditioning_pixel_values"] + 1) / 2
prompts = batch["captions"]
if len(formatted_images) > 10:
break
for img, control_img, prompt in zip(images, control_images, prompts):
formatted_images.append(img)
formatted_control_images.append(control_img)
all_prompts.append(prompt)
logged_artifacts = []
for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
logged_artifacts.append(wandb.Image(img, caption=prompt))
wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
wandb_tracker[0].log({"dataset_samples": logged_artifacts})
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
@@ -2,34 +2,6 @@
This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support add LoRA layers for unet model.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd in the example folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
## Training script example
```bash
@@ -37,7 +9,7 @@ export MODEL_ID="timbrooks/instruct-pix2pix"
export DATASET_ID="instruction-tuning-sd/cartoonization"
export OUTPUT_DIR="instructPix2Pix-cartoonization"
accelerate launch train_instruct_pix2pix_lora.py \
accelerate launch finetune_instruct_pix2pix.py \
--pretrained_model_name_or_path=$MODEL_ID \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
@@ -52,10 +24,7 @@ accelerate launch train_instruct_pix2pix_lora.py \
--rank=4 \
--output_dir=$OUTPUT_DIR \
--report_to=wandb \
--push_to_hub \
--original_image_column="original_image" \
--edited_image_column="cartoonized_image" \
--edit_prompt_column="edit_prompt"
--push_to_hub
```
## Inference
@@ -14,10 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
import argparse
import logging
@@ -33,7 +30,6 @@ import numpy as np
import PIL
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
@@ -43,28 +39,21 @@ from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, cast_training_params
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.32.0.dev0")
check_min_version("0.26.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -74,92 +63,6 @@ DATASET_NAME_MAPPING = {
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
def save_model_card(
repo_id: str,
images: list = None,
base_model: str = None,
dataset_name: str = None,
repo_folder: str = None,
):
img_str = ""
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
model_description = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)
tags = [
"stable-diffusion",
"stable-diffusion-diffusers",
"text-to-image",
"instruct-pix2pix",
"diffusers",
"diffusers-training",
"lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
pipeline,
args,
accelerator,
generator,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
tracker.log({"validation": wandb_table})
return edited_images
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
parser.add_argument(
@@ -514,6 +417,11 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -559,58 +467,49 @@ def main():
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
# it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
# then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
# from the pre-trained checkpoints. For the extra channels added to the first layer, they are
# initialized to zero.
logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
in_channels = 8
out_channels = unet.conv_in.out_channels
unet.register_to_config(in_channels=in_channels)
with torch.no_grad():
new_conv_in = nn.Conv2d(
in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
)
new_conv_in.weight.zero_()
new_conv_in.weight[:, :in_channels, :, :].copy_(unet.conv_in.weight)
unet.conv_in = new_conv_in
# Freeze vae, text_encoder and unet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Freeze the unet parameters before adding adapters
unet.requires_grad_(False)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# Add adapter and make sure the trainable params are in float32.
unet.add_adapter(unet_lora_config)
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(unet, dtype=torch.float32)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
# Create EMA for the unet.
if args.use_ema:
@@ -629,13 +528,6 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
trainable_params = filter(lambda p: p.requires_grad, unet.parameters())
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -648,8 +540,7 @@ def main():
model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
@@ -698,9 +589,9 @@ def main():
else:
optimizer_cls = torch.optim.AdamW
# train on only lora_layers
# train on only unet_lora_parameters
optimizer = optimizer_cls(
trainable_params,
unet_lora_parameters,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -839,27 +730,22 @@ def main():
)
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
)
if args.use_ema:
@@ -879,14 +765,8 @@ def main():
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -1005,7 +885,7 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
@@ -1015,7 +895,7 @@ def main():
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -1023,7 +903,7 @@ def main():
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema_unet.step(trainable_params)
ema_unet.step(unet_lora_parameters)
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -1053,16 +933,6 @@ def main():
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)
StableDiffusionInstructPix2PixPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -1089,22 +959,45 @@ def main():
# The models need unwrapping because for compatibility in distributed training mode.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unwrap_model(unet),
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
log_validation(
pipeline,
args,
accelerator,
generator,
)
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"validation": wandb_table})
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
@@ -1115,47 +1008,22 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
# store only LORA layers
unet = unet.to(torch.float32)
unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionInstructPix2PixPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
unet=unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
unet=unet,
revision=args.revision,
variant=args.variant,
)
pipeline.load_lora_weights(args.output_dir)
images = None
if (args.val_image_url is not None) and (args.validation_prompt is not None):
images = log_validation(
pipeline,
args,
accelerator,
generator,
)
# store only LORA layers
unet.save_attn_procs(args.output_dir)
if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
dataset_name=args.dataset_name,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
@@ -1163,6 +1031,31 @@ def main():
ignore_patterns=["step_*", "epoch_*"],
)
if args.validation_prompt is not None:
edited_images = []
pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device).replace(":0", "")):
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"test": wandb_table})
accelerator.end_training()
@@ -310,7 +310,7 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline):
controlnet=controlnet,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -233,7 +233,7 @@ class PromptDiffusionPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -78,7 +78,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
feature_extractor=feature_extractor,
)
# Copy from statement here and all the methods we take from stable_diffusion_pipeline
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.retriever = retriever
@@ -6,4 +6,4 @@ torch==2.2.0
torchvision>=0.16
ftfy==6.1.1
tensorboard==2.14.0
Jinja2==3.1.5
Jinja2==3.1.4
@@ -765,7 +765,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+4 -8
View File
@@ -25,7 +25,6 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
@@ -90,10 +89,7 @@ def main(args):
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# scheduler
if args.image_size == 4096:
flow_shift = 6.0
else:
flow_shift = 3.0
flow_shift = 3.0
# model config
if args.model_type == "SanaMS_1600M_P1_D20":
@@ -103,7 +99,7 @@ def main(args):
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
for depth in range(layer_num):
# Transformer blocks.
@@ -276,9 +272,9 @@ if __name__ == "__main__":
"--image_size",
default=1024,
type=int,
choices=[512, 1024, 2048, 4096],
choices=[512, 1024, 2048],
required=False,
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
help="Image size of pretrained model, 512, 1024 or 2048.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
-2
View File
@@ -135,7 +135,6 @@ _deps = [
"transformers>=4.41.2",
"urllib3<=2.0.0",
"black",
"phonemizer",
]
# this is a lookup table with items like:
@@ -228,7 +227,6 @@ extras["test"] = deps_list(
"scipy",
"torchvision",
"transformers",
"phonemizer",
)
extras["torch"] = deps_list("torch", "accelerate")
@@ -43,5 +43,4 @@ deps = {
"transformers": "transformers>=4.41.2",
"urllib3": "urllib3<=2.0.0",
"black": "black",
"phonemizer": "phonemizer",
}
+21 -156
View File
@@ -28,20 +28,13 @@ from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
is_transformers_available,
is_transformers_version,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
@@ -50,8 +43,6 @@ from ..utils import (
if is_transformers_available():
from transformers import PreTrainedModel
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -306,152 +297,6 @@ def _best_guess_weight_name(
return weight_name
def _load_lora_into_text_encoder(
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
text_encoder_name="text_encoder",
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def _func_optionally_disable_offloading(_pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
@@ -482,7 +327,27 @@ class LoraBaseMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
@classmethod
def _fetch_state_dict(cls, *args, **kwargs):
@@ -973,178 +973,3 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key
if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)
def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
if "lora_A" in key:
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
else:
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
if "lora_A" in key:
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
else:
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3072
if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
linear1_weight = state_dict.pop(key)
if "lora_A" in key:
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_A.weight"
)
state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
else:
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_B.weight"
)
state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
linear1_bias = state_dict.pop(key)
if "lora_A" in key:
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_A.bias"
)
state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
else:
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
".linear1.lora_B.bias"
)
state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}
# Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
# and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
# sure that both follow the same initial format by stripping off the "transformer." prefix.
for key in list(converted_state_dict.keys()):
if key.startswith("transformer."):
converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
if key.startswith("diffusion_model."):
converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
# Rename and remap the state dict keys
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
# Add back the "transformer." prefix
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
+612 -79
View File
@@ -20,24 +20,22 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..utils import (
USE_PEFT_BACKEND,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_peft_available,
is_peft_version,
is_torch_version,
is_transformers_available,
is_transformers_version,
logging,
scale_lora_layers,
)
from .lora_base import ( # noqa
LORA_WEIGHT_NAME,
LORA_WEIGHT_NAME_SAFE,
LoraBaseMixin,
_fetch_state_dict,
_load_lora_into_text_encoder,
)
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
@@ -56,6 +54,9 @@ if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
if is_transformers_available():
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
@@ -347,17 +348,119 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
def save_lora_weights(
@@ -788,17 +891,119 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
def save_lora_weights(
@@ -1195,17 +1400,119 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
def save_lora_weights(
@@ -1725,17 +2032,119 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
@@ -1794,7 +2203,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
def fuse_lora(
self,
components: List[str] = ["transformer"],
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -2188,17 +2597,119 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
network_alphas=network_alphas,
lora_scale=lora_scale,
text_encoder=text_encoder,
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
def save_lora_weights(
@@ -2496,9 +3007,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer"],
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -2539,7 +3051,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -2553,6 +3066,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -2799,9 +3315,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer"],
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -2842,7 +3359,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -2856,6 +3374,9 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -3102,9 +3623,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer"],
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -3145,7 +3667,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3159,6 +3682,9 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -3405,9 +3931,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer"],
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -3448,7 +3975,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3462,6 +3990,9 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -3476,6 +4007,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3486,7 +4018,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
<Tip warning={true}>
We support loading original format HunyuanVideo LoRA checkpoints.
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
@@ -3569,10 +4101,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
if is_original_hunyuan_video:
state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
@@ -3711,9 +4239,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer"],
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -3754,7 +4283,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3768,6 +4298,9 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
+35 -12
View File
@@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Union
import safetensors
import torch
import torch.nn as nn
from ..utils import (
MIN_PEFT_VERSION,
@@ -29,16 +30,20 @@ from ..utils import (
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .lora_base import _fetch_state_dict
from .unet_loader_utils import _maybe_expand_lora_scales
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
_SET_ADAPTER_SCALE_FN_MAPPING = {
@@ -135,7 +140,27 @@ class PeftAdapterMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
r"""
@@ -300,17 +325,15 @@ class PeftAdapterMixin:
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)
except RuntimeError as e:
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)
self.peft_config.pop(adapter_name)
self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
raise
-8
View File
@@ -60,7 +60,6 @@ def load_single_file_sub_model(
local_files_only=False,
torch_dtype=None,
is_legacy_loading=False,
disable_mmap=False,
**kwargs,
):
if is_pipeline_module:
@@ -107,7 +106,6 @@ def load_single_file_sub_model(
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
disable_mmap=disable_mmap,
**kwargs,
)
@@ -310,9 +308,6 @@ class FromSingleFileMixin:
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -360,7 +355,6 @@ class FromSingleFileMixin:
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
disable_mmap = kwargs.pop("disable_mmap", False)
is_legacy_loading = False
@@ -389,7 +383,6 @@ class FromSingleFileMixin:
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)
if config is None:
@@ -511,7 +504,6 @@ class FromSingleFileMixin:
original_config=original_config,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
disable_mmap=disable_mmap,
**kwargs,
)
except SingleFileComponentError as e:
@@ -25,7 +25,6 @@ from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
@@ -107,10 +106,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AuraFlowTransformer2DModel": {
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}
@@ -187,9 +182,6 @@ class FromOriginalModelMixin:
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
@@ -237,7 +229,6 @@ class FromOriginalModelMixin:
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
@@ -250,7 +241,6 @@ class FromOriginalModelMixin:
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
+4 -113
View File
@@ -94,12 +94,6 @@ CHECKPOINT_KEY_NAMES = {
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
"auraflow": [
"double_layers.0.attn.w2q.weight",
"double_layers.0.attn.w1q.weight",
"cond_seq_linear.weight",
"t_embedder.mlp.0.weight",
],
"flux": [
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -160,7 +154,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -186,7 +179,6 @@ DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
"inpainting": 512,
"inpainting_v2": 512,
"controlnet": 512,
"instruct-pix2pix": 512,
"v2": 768,
"v1": 512,
}
@@ -388,7 +380,6 @@ def load_single_file_checkpoint(
cache_dir=None,
local_files_only=None,
revision=None,
disable_mmap=False,
):
if os.path.isfile(pretrained_model_link_or_path):
pretrained_model_link_or_path = pretrained_model_link_or_path
@@ -406,7 +397,7 @@ def load_single_file_checkpoint(
revision=revision,
)
checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
checkpoint = load_state_dict(pretrained_model_link_or_path)
# some checkpoints contain the model state dict under a "state_dict" key
while "state_dict" in checkpoint:
@@ -606,14 +597,10 @@ def infer_diffusers_model_type(checkpoint):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
if "model.diffusion_model.img_in.weight" in checkpoint:
key = "model.diffusion_model.img_in.weight"
else:
key = "img_in.weight"
if checkpoint[key].shape[1] == 384:
if checkpoint["img_in.weight"].shape[1] == 384:
model_type = "flux-fill"
elif checkpoint[key].shape[1] == 128:
elif checkpoint["img_in.weight"].shape[1] == 128:
model_type = "flux-depth"
else:
model_type = "flux-dev"
@@ -648,9 +635,6 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
model_type = "auraflow"
elif (
CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
@@ -2106,7 +2090,6 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2706,95 +2689,3 @@ def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, checkpoint)
return checkpoint
def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
state_dict_keys = list(checkpoint.keys())
# Handle register tokens and positional embeddings
converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
# Handle time step projection
converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
# Handle context embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
# Calculate the number of layers
def calculate_layers(keys, key_prefix):
layers = set()
for k in keys:
if key_prefix in k:
layer_num = int(k.split(".")[1]) # get the layer number
layers.add(layer_num)
return len(layers)
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
# MMDiT blocks
for i in range(mmdit_layers):
# Feed-forward
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for orig_k, diffuser_k in path_mapping.items():
for k, v in weight_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
f"double_layers.{i}.{orig_k}.{k}.weight", None
)
# Norms
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
for orig_k, diffuser_k in path_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
f"double_layers.{i}.{orig_k}.1.weight", None
)
# Attentions
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
for k, v in attn_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
f"double_layers.{i}.attn.{k}.weight", None
)
# Single-DiT blocks
for i in range(single_dit_layers):
# Feed-forward
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for k, v in mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
f"single_layers.{i}.mlp.{k}.weight", None
)
# Norms
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
f"single_layers.{i}.modCX.1.weight", None
)
# Attentions
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
for k, v in x_attn_mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
f"single_layers.{i}.attn.{k}.weight", None
)
# Final blocks
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
# Handle the final norm layer
norm_weight = checkpoint.pop("modF.1.weight", None)
if norm_weight is not None:
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
else:
converted_state_dict["norm_out.linear.weight"] = None
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
return converted_state_dict
+25 -2
View File
@@ -21,6 +21,7 @@ import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..models.embeddings import (
ImageProjection,
@@ -43,11 +44,13 @@ from ..utils import (
is_torch_version,
logging,
)
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
@@ -408,7 +411,27 @@ class UNet2DConditionLoadersMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
def save_attn_procs(
self,
@@ -486,9 +486,6 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
@@ -518,8 +515,6 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
@@ -611,106 +606,11 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, x.shape[2], self.tile_sample_stride_height):
row = []
for j in range(0, x.shape[3], self.tile_sample_stride_width):
tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
if (
tile.shape[2] % self.spatial_compression_ratio != 0
or tile.shape[3] % self.spatial_compression_ratio != 0
):
pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
tile = F.pad(tile, (0, pad_w, 0, pad_h))
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=3))
encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
if not return_dict:
return (encoded,)
return EncoderOutput(latent=encoded)
raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=3))
decoded = torch.cat(result_rows, dim=2)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
encoded = self.encode(sample, return_dict=False)[0]
@@ -1010,12 +1010,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 512
self.tile_sample_min_width = 512
self.tile_sample_min_num_frames = 16
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
self.tile_sample_stride_num_frames = 8
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
@@ -1025,10 +1023,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_min_num_frames: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_sample_stride_num_frames: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1050,10 +1046,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
@@ -1079,13 +1073,18 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
return self._temporal_tiled_encode(x)
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
enc = self.encoder(x)
if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
enc = self.encoder(x)
return enc
@@ -1122,15 +1121,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, temb, return_dict=return_dict)
dec = self.decoder(z, temb)
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
dec = self.decoder(z, temb)
if not return_dict:
return (dec,)
@@ -1186,14 +1189,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
@@ -1222,9 +1217,17 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
)
if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
)
row.append(time)
rows.append(row)
@@ -1280,7 +1283,17 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.decoder(
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
)
row.append(time)
rows.append(row)
@@ -1305,74 +1318,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return DecoderOutput(sample=dec)
def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
batch_size, num_channels, num_frames, height, width = x.shape
latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
row = []
for i in range(0, num_frames, self.tile_sample_stride_num_frames):
tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
tile = self.tiled_encode(tile)
else:
tile = self.encoder(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
else:
result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
return enc
def _temporal_tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
row = []
for i in range(0, num_frames, tile_latent_stride_num_frames):
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
decoded = self.tiled_decode(tile, temb, return_dict=True).sample
else:
decoded = self.decoder(tile, temb)
if i > 0:
decoded = decoded[:, :, :-1, :, :]
row.append(decoded)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
result_row.append(tile)
else:
result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
@@ -1389,5 +1334,5 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
z = posterior.mode()
dec = self.decode(z, temb)
if not return_dict:
return (dec.sample,)
return (dec,)
return dec
+2 -7
View File
@@ -131,9 +131,7 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class
def load_state_dict(
checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
):
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
@@ -144,10 +142,7 @@ def load_state_dict(
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
+7 -7
View File
@@ -559,9 +559,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
<Tip>
@@ -607,7 +604,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None)
disable_mmap = kwargs.pop("disable_mmap", False)
allow_pickle = False
if use_safetensors is None:
@@ -887,7 +883,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
else:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
@@ -924,12 +920,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
force_hook = True
device_map = _determine_device_map(
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
)
if device_map is None and is_sharded:
# we load the parameters on the cpu
device_map = {"": "cpu"}
force_hook = False
try:
accelerate.load_checkpoint_and_dispatch(
model,
@@ -939,6 +937,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
force_hooks=force_hook,
strict=True,
)
except AttributeError as e:
@@ -968,6 +967,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
force_hooks=force_hook,
strict=True,
)
model._undo_temp_convert_self_to_deprecated_attention_blocks()
@@ -983,7 +983,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
@@ -1214,7 +1214,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
get the underlying `_no_split_modules`.
Args:
@@ -20,7 +20,6 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
@@ -254,7 +253,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -120,10 +120,8 @@ class CogVideoXBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -135,7 +133,6 @@ class CogVideoXBlock(nn.Module):
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -213,7 +210,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
@register_to_config
def __init__(
@@ -501,7 +497,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
encoder_hidden_states,
emb,
image_rotary_emb,
attention_kwargs,
**ckpt_kwargs,
)
else:
@@ -510,7 +505,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
)
if not self.config.use_rotary_positional_embeddings:
@@ -221,8 +221,6 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
Scaling factor to apply in 3D positional embeddings across time dimension.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@@ -166,7 +166,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
@register_to_config
def __init__(
@@ -542,12 +542,6 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"""
_supports_gradient_checkpointing = True
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
@register_to_config
def __init__(
@@ -719,16 +713,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros(
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N]
batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i]] = True
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
attention_mask = attention_mask.unsqueeze(1) # [B, 1, N, N], for broadcasting across attention heads
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
+15 -20
View File
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -103,7 +103,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
mid_block_type: Optional[str] = "UNetMidBlock2D",
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
layers_per_block: int = 2,
@@ -195,22 +194,19 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.down_blocks.append(down_block)
# mid
if mid_block_type is None:
self.mid_block = None
else:
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
@@ -326,8 +322,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(sample, emb)
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
@@ -33,7 +33,6 @@ from ...utils import (
deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -42,14 +41,6 @@ from ...video_processor import VideoProcessor
from .pipeline_output import AllegroPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__)
if is_bs4_available():
@@ -203,10 +194,10 @@ class AllegroPipeline(DiffusionPipeline):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -930,9 +921,6 @@ class AllegroPipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents)
@@ -20,18 +20,10 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import is_torch_xla_available, replace_example_docstring
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -74,9 +66,7 @@ class AmusedPipeline(DiffusionPipeline):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@@ -307,9 +297,6 @@ class AmusedPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
output = latents
else:
@@ -20,18 +20,10 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import is_torch_xla_available, replace_example_docstring
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -89,9 +81,7 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@@ -333,9 +323,6 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
output = latents
else:
@@ -21,18 +21,10 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import is_torch_xla_available, replace_example_docstring
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -97,9 +89,7 @@ class AmusedInpaintPipeline(DiffusionPipeline):
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
@@ -364,9 +354,6 @@ class AmusedInpaintPipeline(DiffusionPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
output = latents
else:
@@ -34,7 +34,6 @@ from ...schedulers import (
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -48,16 +47,8 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -148,7 +139,7 @@ class AnimateDiffPipeline(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -853,9 +844,6 @@ class AnimateDiffPipeline(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
# 9. Post processing
if output_type == "latent":
video = latents

Some files were not shown because too many files have changed in this diff Show More