Compare commits
52 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 44fb2fd6ae | |||
| 74b67524b5 | |||
| 794f7e49a9 | |||
| 9fc9c6dd71 | |||
| df355ea2c6 | |||
| ae019da9e3 | |||
| 329771e542 | |||
| f7cb595428 | |||
| c3478a42b9 | |||
| 980736b792 | |||
| 50c81df4e7 | |||
| e1c7269720 | |||
| edb8c1bce6 | |||
| 0785dba4df | |||
| 5cda8ea521 | |||
| 36acdd7517 | |||
| e7db062e10 | |||
| 1b0fe63656 | |||
| d6c030fd37 | |||
| 9f06a0d1a4 | |||
| 52c05bd4cd | |||
| a6f043a80f | |||
| 12fbe3f7dc | |||
| 83ba01a38d | |||
| 7116fd24e5 | |||
| 553b13845f | |||
| 7bc8b92384 | |||
| f0c6d9784b | |||
| d006f0769b | |||
| a26d57097a | |||
| daf9d0f119 | |||
| 95c5ce4e6f | |||
| c0964571fc | |||
| b13cdbb294 | |||
| a0acbdc989 | |||
| 5655b22ead | |||
| 4df9d49218 | |||
| 9731773d39 | |||
| e2deb82e69 | |||
| 1288c8560a | |||
| cb342b745a | |||
| 80fd9260bb | |||
| 71ad16b463 | |||
| ee7e141d80 | |||
| 01bd79649e | |||
| 03bcf5aefe | |||
| e0b96ba7b0 | |||
| 854a04659c | |||
| 628f2c544a | |||
| 811560b1d7 | |||
| f1e0c7ce4a | |||
| b94cfd7937 |
@@ -272,7 +272,7 @@ jobs:
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_minimum_version_cuda \
|
||||
tests/models/test_modelling_common.py \
|
||||
tests/models/test_modeling_common.py \
|
||||
tests/pipelines/test_pipelines_common.py \
|
||||
tests/pipelines/test_pipeline_utils.py \
|
||||
tests/pipelines/test_pipelines.py \
|
||||
|
||||
@@ -266,6 +266,7 @@ jobs:
|
||||
# TODO (sayakpaul, DN6): revisit `--no-deps`
|
||||
python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
python -m uv pip install -U tokenizers
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -193,7 +193,7 @@ jobs:
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_minimum_cuda \
|
||||
tests/models/test_modelling_common.py \
|
||||
tests/models/test_modeling_common.py \
|
||||
tests/pipelines/test_pipelines_common.py \
|
||||
tests/pipelines/test_pipeline_utils.py \
|
||||
tests/pipelines/test_pipelines.py \
|
||||
|
||||
@@ -62,6 +62,33 @@ image = pipeline(prompt).images[0]
|
||||
image.save("auraflow.png")
|
||||
```
|
||||
|
||||
Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) is also supported:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import (
|
||||
AuraFlowPipeline,
|
||||
GGUFQuantizationConfig,
|
||||
AuraFlowTransformer2DModel,
|
||||
)
|
||||
|
||||
transformer = AuraFlowTransformer2DModel.from_single_file(
|
||||
"https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipeline = AuraFlowPipeline.from_pretrained(
|
||||
"fal/AuraFlow-v0.3",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
prompt = "a cute pony in a field of flowers"
|
||||
image = pipeline(prompt).images[0]
|
||||
image.save("auraflow.png")
|
||||
```
|
||||
|
||||
## AuraFlowPipeline
|
||||
|
||||
[[autodoc]] AuraFlowPipeline
|
||||
|
||||
@@ -367,7 +367,7 @@ transformer_8bit = FluxTransformer2DModel.from_pretrained(
|
||||
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
text_encoder=text_encoder_8bit,
|
||||
text_encoder_2=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.
|
||||
|
||||
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*
|
||||
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/tencent/HunyuanVideo).*
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -45,14 +45,14 @@ from diffusers.utils import export_to_video
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
"tencent/HunyuanVideo",
|
||||
"hunyuanvideo-community/HunyuanVideo",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
"tencent/HunyuanVideo",
|
||||
"hunyuanvideo-community/HunyuanVideo",
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
|
||||
@@ -59,10 +59,10 @@ Refer to the [Quantization](../../quantization/overview) overview to learn more
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModelForCausalLM
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = AutoModelForCausalLM.from_pretrained(
|
||||
text_encoder_8bit = AutoModel.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
|
||||
@@ -78,10 +78,10 @@ from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
"tencent/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
"hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = HunyuanVideoPipeline.from_pretrained(
|
||||
"tencent/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
|
||||
"hunyuanvideo-community/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
# reduce memory requirements
|
||||
|
||||
@@ -67,6 +67,17 @@ write_basic_config()
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens), and press Enter.
|
||||
|
||||
> [!NOTE]
|
||||
> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
|
||||
> `pip install wandb`
|
||||
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
|
||||
|
||||
### Pivotal Tuning
|
||||
**Training with text encoder(s)**
|
||||
|
||||
|
||||
@@ -65,6 +65,17 @@ write_basic_config()
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens), and press Enter.
|
||||
|
||||
> [!NOTE]
|
||||
> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
|
||||
> `pip install wandb`
|
||||
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
|
||||
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer (DiT). With this change, we may also want to explore
|
||||
|
||||
@@ -33,12 +33,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
|
||||
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
|
||||
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_comparison.ipynb) | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
|
||||
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
|
||||
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) |
|
||||
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_unclip.ipynb) | [Ray Wang](https://wrong.wang) |
|
||||
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_text_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) |
|
||||
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
|
||||
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
|
||||
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
|
||||
@@ -50,7 +50,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon) |
|
||||
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
|
||||
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline that supports unlimited-length prompts and negative prompts, using A1111-style prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
|
||||
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
|
||||
@@ -372,7 +372,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
self.register_adaptive_mask_model()
|
||||
self.register_adaptive_mask_settings()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -386,7 +386,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
|
||||
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
|
||||
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
|
||||
@@ -416,10 +416,14 @@ class AdaptiveMaskInpaintPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -438,7 +442,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
|
||||
if unet.config.in_channels != 9:
|
||||
if unet is not None and unet.config.in_channels != 9:
|
||||
logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
|
||||
|
||||
self.register_modules(
|
||||
@@ -450,7 +454,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -103,7 +103,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -132,10 +132,14 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -162,7 +166,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
|
||||
@@ -35,7 +35,7 @@ class EDICTPipeline(DiffusionPipeline):
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
def _encode_prompt(
|
||||
|
||||
@@ -1342,7 +1342,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
|
||||
@@ -221,7 +221,7 @@ class GlueGenStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin, St
|
||||
language_adapter=language_adapter,
|
||||
tensor_norm=tensor_norm,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
|
||||
@@ -109,7 +109,7 @@ class InstaFlowPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -123,7 +123,7 @@ class InstaFlowPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -152,10 +152,14 @@ class InstaFlowPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -182,7 +186,7 @@ class InstaFlowPipeline(
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
|
||||
@@ -191,7 +191,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -205,7 +205,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -234,10 +234,14 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -265,7 +269,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -463,6 +463,6 @@ class StableDiffusionHighResFixPipeline(StableDiffusionPipeline):
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
@@ -69,7 +69,7 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
def _encode_prompt(
|
||||
|
||||
@@ -273,7 +273,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ class LatentConsistencyModelPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
def _encode_prompt(
|
||||
|
||||
@@ -336,7 +336,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
# This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -350,7 +350,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -379,10 +379,14 @@ class LLMGroundedDiffusionPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -410,7 +414,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -496,7 +496,7 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -510,7 +510,7 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -539,10 +539,14 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -568,7 +572,7 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(
|
||||
|
||||
@@ -673,12 +673,16 @@ class SDXLLongPromptWeightingPipeline(
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
||||
)
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
self.default_sample_size = (
|
||||
self.unet.config.sample_size
|
||||
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
@@ -827,7 +831,9 @@ class SDXLLongPromptWeightingPipeline(
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
@@ -879,7 +885,8 @@ class SDXLLongPromptWeightingPipeline(
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
@@ -3766,7 +3766,7 @@ class MatryoshkaPipeline(
|
||||
else:
|
||||
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -3780,7 +3780,7 @@ class MatryoshkaPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
# if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
# deprecation_message = (
|
||||
# f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
# " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -3793,10 +3793,14 @@ class MatryoshkaPipeline(
|
||||
# new_config["clip_sample"] = False
|
||||
# scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
|
||||
@@ -98,7 +98,7 @@ class MultilingualStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
|
||||
@@ -188,7 +188,7 @@ class AnimateDiffControlNetPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
|
||||
@@ -308,7 +308,7 @@ class AnimateDiffImgToVideoPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
|
||||
@@ -162,7 +162,7 @@ class AnimateDiffPipelineIpex(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
|
||||
@@ -166,9 +166,13 @@ class DemoFusionSDXLPipeline(
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
self.default_sample_size = (
|
||||
self.unet.config.sample_size
|
||||
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
@@ -290,7 +294,9 @@ class DemoFusionSDXLPipeline(
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
@@ -342,7 +348,8 @@ class DemoFusionSDXLPipeline(
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
@@ -150,10 +150,14 @@ class FabricPipeline(DiffusionPipeline):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -179,7 +183,7 @@ class FabricPipeline(DiffusionPipeline):
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
|
||||
@@ -221,13 +221,12 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor,
|
||||
vae_latent_channels=self.vae.config.latent_channels,
|
||||
vae_latent_channels=latent_channels,
|
||||
do_normalize=False,
|
||||
do_binarize=False,
|
||||
do_convert_grayscale=True,
|
||||
@@ -876,10 +875,10 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
|
||||
@@ -219,9 +219,7 @@ class RFInversionFluxPipeline(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
@@ -419,7 +417,7 @@ class RFInversionFluxPipeline(
|
||||
)
|
||||
image = image.to(dtype)
|
||||
|
||||
x0 = self.vae.encode(image.to(self.device)).latent_dist.sample()
|
||||
x0 = self.vae.encode(image.to(self._execution_device)).latent_dist.sample()
|
||||
x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
x0 = x0.to(dtype)
|
||||
return x0, resized
|
||||
@@ -822,10 +820,10 @@ class RFInversionFluxPipeline(
|
||||
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
@@ -992,10 +990,10 @@ class RFInversionFluxPipeline(
|
||||
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inversion_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
|
||||
@@ -64,6 +64,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
@@ -189,9 +190,7 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
@@ -757,10 +756,10 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.16),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
|
||||
@@ -327,9 +327,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor,
|
||||
|
||||
@@ -209,16 +209,18 @@ class KolorsDifferentialImg2ImgPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
|
||||
)
|
||||
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
self.default_sample_size = (
|
||||
self.unet.config.sample_size
|
||||
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
|
||||
@@ -131,7 +131,7 @@ class Prompt2PromptPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -145,7 +145,7 @@ class Prompt2PromptPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -174,10 +174,14 @@ class Prompt2PromptPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -205,7 +209,7 @@ class Prompt2PromptPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -488,13 +488,17 @@ class StyleAlignedSDXLPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
||||
)
|
||||
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
self.default_sample_size = (
|
||||
self.unet.config.sample_size
|
||||
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
@@ -628,7 +632,9 @@ class StyleAlignedSDXLPipeline(
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
@@ -688,7 +694,8 @@ class StyleAlignedSDXLPipeline(
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
@@ -207,7 +207,7 @@ class StableDiffusion3DifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
|
||||
)
|
||||
|
||||
@@ -417,7 +417,7 @@ class StableDiffusionBoxDiffPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -431,7 +431,7 @@ class StableDiffusionBoxDiffPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -460,10 +460,14 @@ class StableDiffusionBoxDiffPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -491,7 +495,7 @@ class StableDiffusionBoxDiffPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -384,7 +384,7 @@ class StableDiffusionPAGPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -398,7 +398,7 @@ class StableDiffusionPAGPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -427,10 +427,14 @@ class StableDiffusionPAGPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -458,7 +462,7 @@ class StableDiffusionPAGPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
|
||||
watermarker=watermarker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample="bilinear")
|
||||
# self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
self.register_to_config(max_noise_level=max_noise_level)
|
||||
|
||||
@@ -226,12 +226,16 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
)
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
self.default_sample_size = (
|
||||
self.unet.config.sample_size
|
||||
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
@@ -359,7 +363,9 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
@@ -419,7 +425,8 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
@@ -374,12 +374,16 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
)
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
self.default_sample_size = (
|
||||
self.unet.config.sample_size
|
||||
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
@@ -507,7 +511,9 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
@@ -567,7 +573,8 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
@@ -258,7 +258,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
@@ -394,7 +394,9 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
@@ -454,7 +456,8 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
@@ -253,10 +253,14 @@ class StableDiffusionXLPipelineIpex(
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
self.default_sample_size = (
|
||||
self.unet.config.sample_size
|
||||
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
@@ -390,7 +394,9 @@ class StableDiffusionXLPipelineIpex(
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
@@ -450,7 +456,8 @@ class StableDiffusionXLPipelineIpex(
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
|
||||
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
@@ -108,7 +108,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -122,7 +122,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -151,10 +151,14 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -181,7 +185,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
feature_extractor=feature_extractor,
|
||||
cc_projection=cc_projection,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
# self.model_mode = None
|
||||
|
||||
|
||||
@@ -352,7 +352,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
@@ -632,7 +632,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.
|
||||
control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.
|
||||
control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame.
|
||||
strength ('float'): SDEdit strength.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -789,7 +789,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
# Currently we only support single control
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
control_image = self.prepare_control_image(
|
||||
image=control_frames[0],
|
||||
image=control_frames(frames[0]) if callable(control_frames) else control_frames[0],
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size,
|
||||
@@ -908,6 +908,9 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
else:
|
||||
@@ -924,7 +927,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
for idx in range(1, len(frames)):
|
||||
image = frames[idx]
|
||||
prev_image = frames[idx - 1]
|
||||
control_image = control_frames[idx]
|
||||
control_image = control_frames(image) if callable(control_frames) else control_frames[idx]
|
||||
# 5.1 prepare frames
|
||||
image = self.image_processor.preprocess(image).to(dtype=self.dtype)
|
||||
prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
|
||||
|
||||
@@ -179,7 +179,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, StableDiffusio
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def _encode_prompt(
|
||||
|
||||
@@ -278,7 +278,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, StableDiffusio
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def _encode_prompt(
|
||||
|
||||
@@ -263,7 +263,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline, StableD
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def _encode_prompt(
|
||||
|
||||
@@ -105,7 +105,7 @@ class StableDiffusionIPEXPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -119,7 +119,7 @@ class StableDiffusionIPEXPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -148,10 +148,14 @@ class StableDiffusionIPEXPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -178,7 +182,7 @@ class StableDiffusionIPEXPipeline(
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1):
|
||||
|
||||
@@ -66,7 +66,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
|
||||
@@ -132,7 +132,7 @@ class StableDiffusionReferencePipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -146,7 +146,7 @@ class StableDiffusionReferencePipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
|
||||
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
|
||||
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
|
||||
@@ -181,10 +181,14 @@ class StableDiffusionReferencePipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -202,7 +206,7 @@ class StableDiffusionReferencePipeline(
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
|
||||
if unet.config.in_channels != 4:
|
||||
if unet is not None and unet.config.in_channels != 4:
|
||||
logger.warning(
|
||||
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
|
||||
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
|
||||
@@ -219,7 +223,7 @@ class StableDiffusionReferencePipeline(
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class StableDiffusionRepaintPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -201,7 +201,7 @@ class StableDiffusionRepaintPipeline(
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
|
||||
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
|
||||
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
|
||||
@@ -236,10 +236,14 @@ class StableDiffusionRepaintPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -257,7 +261,7 @@ class StableDiffusionRepaintPipeline(
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
|
||||
if unet.config.in_channels != 4:
|
||||
if unet is not None and unet.config.in_channels != 4:
|
||||
logger.warning(
|
||||
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
|
||||
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
|
||||
@@ -274,7 +278,7 @@ class StableDiffusionRepaintPipeline(
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
|
||||
@@ -710,7 +710,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -724,7 +724,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -753,10 +753,14 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -806,7 +810,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
self.engine = {} # loaded in build_engines()
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -714,7 +714,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -728,7 +728,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -757,10 +757,14 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -810,7 +814,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
self.engine = {} # loaded in build_engines()
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -626,7 +626,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -640,7 +640,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
@@ -669,10 +669,14 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
is_unet_version_less_0_9_0 = (
|
||||
unet is not None
|
||||
and hasattr(unet.config, "_diffusers_version")
|
||||
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
||||
)
|
||||
is_unet_sample_size_less_64 = (
|
||||
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
)
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
@@ -722,7 +726,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
self.engine = {} # loaded in build_engines()
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
@@ -85,7 +85,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
|
||||
if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
|
||||
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
|
||||
|
||||
@@ -120,7 +120,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
|
||||
@@ -121,7 +121,7 @@ prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
condition_image=image,
|
||||
control_image=image,
|
||||
num_inference_steps=50,
|
||||
joint_attention_kwargs={"scale": 0.9},
|
||||
guidance_scale=25.,
|
||||
@@ -190,7 +190,7 @@ prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
condition_image=image,
|
||||
control_image=image,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=25.,
|
||||
).images[0]
|
||||
@@ -200,5 +200,5 @@ gen_images.save("output.png")
|
||||
## Things to note
|
||||
|
||||
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
|
||||
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
|
||||
* The scripts are not memory-optimized, but if `--offload` is specified, we offload the VAE and the text encoders to the CPU when they are not in use.
|
||||
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
|
||||
@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
|
||||
|
||||
for _ in range(args.num_validation_images):
|
||||
with autocast_ctx:
|
||||
# need to fix in pipeline_flux_controlnet
|
||||
image = pipeline(
|
||||
prompt=validation_prompt,
|
||||
control_image=validation_image,
|
||||
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
|
||||
images = log["images"]
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
|
||||
formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
|
||||
for image in images:
|
||||
image = wandb.Image(image, caption=validation_prompt)
|
||||
formatted_images.append(image)
|
||||
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
|
||||
img_str += f"\n"
|
||||
|
||||
model_description = f"""
|
||||
# control-lora-{repo_id}
|
||||
# flux-control-{repo_id}
|
||||
|
||||
These are Control weights trained on {base_model} with new type of conditioning.
|
||||
{img_str}
|
||||
@@ -434,7 +433,7 @@ def parse_args(input_args=None):
|
||||
"--conditioning_image_column",
|
||||
type=str,
|
||||
default="conditioning_image",
|
||||
help="The column of the dataset containing the controlnet conditioning image.",
|
||||
help="The column of the dataset containing the control conditioning image.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
@@ -442,6 +441,7 @@ def parse_args(input_args=None):
|
||||
default="text",
|
||||
help="The column of the dataset containing a caption or a list of captions.",
|
||||
)
|
||||
parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
nargs="+",
|
||||
help=(
|
||||
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
|
||||
"A set of paths to the control conditioning image be evaluated every `--validation_steps`"
|
||||
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
|
||||
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
|
||||
" `--validation_image` that will be used with all `--validation_prompt`s."
|
||||
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="Path to the jsonl file containing the training data.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--only_target_transformer_blocks",
|
||||
action="store_true",
|
||||
help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
@@ -581,7 +585,7 @@ def parse_args(input_args=None):
|
||||
|
||||
if args.resolution % 8 != 0:
|
||||
raise ValueError(
|
||||
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
|
||||
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
|
||||
)
|
||||
|
||||
return args
|
||||
@@ -665,7 +669,12 @@ def prepare_train_dataset(dataset, accelerator):
|
||||
conditioning_images = [image_transforms(image) for image in conditioning_images]
|
||||
examples["pixel_values"] = images
|
||||
examples["conditioning_pixel_values"] = conditioning_images
|
||||
examples["captions"] = list(examples[args.caption_column])
|
||||
|
||||
is_caption_list = isinstance(examples[args.caption_column][0], list)
|
||||
if is_caption_list:
|
||||
examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
|
||||
else:
|
||||
examples["captions"] = list(examples[args.caption_column])
|
||||
|
||||
return examples
|
||||
|
||||
@@ -765,7 +774,8 @@ def main(args):
|
||||
subfolder="scheduler",
|
||||
)
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
flux_transformer.requires_grad_(True)
|
||||
if not args.only_target_transformer_blocks:
|
||||
flux_transformer.requires_grad_(True)
|
||||
vae.requires_grad_(False)
|
||||
|
||||
# cast down and move to the CPU
|
||||
@@ -797,6 +807,12 @@ def main(args):
|
||||
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
|
||||
|
||||
if args.only_target_transformer_blocks:
|
||||
flux_transformer.x_embedder.requires_grad_(True)
|
||||
for name, module in flux_transformer.named_modules():
|
||||
if "transformer_blocks" in name:
|
||||
module.requires_grad_(True)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
@@ -974,6 +990,32 @@ def main(args):
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
|
||||
logger.info("Logging some dataset samples.")
|
||||
formatted_images = []
|
||||
formatted_control_images = []
|
||||
all_prompts = []
|
||||
for i, batch in enumerate(train_dataloader):
|
||||
images = (batch["pixel_values"] + 1) / 2
|
||||
control_images = (batch["conditioning_pixel_values"] + 1) / 2
|
||||
prompts = batch["captions"]
|
||||
|
||||
if len(formatted_images) > 10:
|
||||
break
|
||||
|
||||
for img, control_img, prompt in zip(images, control_images, prompts):
|
||||
formatted_images.append(img)
|
||||
formatted_control_images.append(control_img)
|
||||
all_prompts.append(prompt)
|
||||
|
||||
logged_artifacts = []
|
||||
for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
|
||||
logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
|
||||
logged_artifacts.append(wandb.Image(img, caption=prompt))
|
||||
|
||||
wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
|
||||
wandb_tracker[0].log({"dataset_samples": logged_artifacts})
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
|
||||
@@ -132,7 +132,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
|
||||
|
||||
for _ in range(args.num_validation_images):
|
||||
with autocast_ctx:
|
||||
# need to fix in pipeline_flux_controlnet
|
||||
image = pipeline(
|
||||
prompt=validation_prompt,
|
||||
control_image=validation_image,
|
||||
@@ -169,7 +168,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
|
||||
images = log["images"]
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
|
||||
formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
|
||||
for image in images:
|
||||
image = wandb.Image(image, caption=validation_prompt)
|
||||
formatted_images.append(image)
|
||||
@@ -198,7 +197,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
|
||||
img_str += f"\n"
|
||||
|
||||
model_description = f"""
|
||||
# controlnet-lora-{repo_id}
|
||||
# control-lora-{repo_id}
|
||||
|
||||
These are Control LoRA weights trained on {base_model} with new type of conditioning.
|
||||
{img_str}
|
||||
@@ -256,7 +255,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="controlnet-lora",
|
||||
default="control-lora",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -466,7 +465,7 @@ def parse_args(input_args=None):
|
||||
"--conditioning_image_column",
|
||||
type=str,
|
||||
default="conditioning_image",
|
||||
help="The column of the dataset containing the controlnet conditioning image.",
|
||||
help="The column of the dataset containing the control conditioning image.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
@@ -474,6 +473,7 @@ def parse_args(input_args=None):
|
||||
default="text",
|
||||
help="The column of the dataset containing a caption or a list of captions.",
|
||||
)
|
||||
parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
@@ -500,7 +500,7 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
nargs="+",
|
||||
help=(
|
||||
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
|
||||
"A set of paths to the control conditioning image be evaluated every `--validation_steps`"
|
||||
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
|
||||
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
|
||||
" `--validation_image` that will be used with all `--validation_prompt`s."
|
||||
@@ -613,7 +613,7 @@ def parse_args(input_args=None):
|
||||
|
||||
if args.resolution % 8 != 0:
|
||||
raise ValueError(
|
||||
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
|
||||
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
|
||||
)
|
||||
|
||||
return args
|
||||
@@ -697,7 +697,12 @@ def prepare_train_dataset(dataset, accelerator):
|
||||
conditioning_images = [image_transforms(image) for image in conditioning_images]
|
||||
examples["pixel_values"] = images
|
||||
examples["conditioning_pixel_values"] = conditioning_images
|
||||
examples["captions"] = list(examples[args.caption_column])
|
||||
|
||||
is_caption_list = isinstance(examples[args.caption_column][0], list)
|
||||
if is_caption_list:
|
||||
examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
|
||||
else:
|
||||
examples["captions"] = list(examples[args.caption_column])
|
||||
|
||||
return examples
|
||||
|
||||
@@ -1132,6 +1137,32 @@ def main(args):
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
|
||||
logger.info("Logging some dataset samples.")
|
||||
formatted_images = []
|
||||
formatted_control_images = []
|
||||
all_prompts = []
|
||||
for i, batch in enumerate(train_dataloader):
|
||||
images = (batch["pixel_values"] + 1) / 2
|
||||
control_images = (batch["conditioning_pixel_values"] + 1) / 2
|
||||
prompts = batch["captions"]
|
||||
|
||||
if len(formatted_images) > 10:
|
||||
break
|
||||
|
||||
for img, control_img, prompt in zip(images, control_images, prompts):
|
||||
formatted_images.append(img)
|
||||
formatted_control_images.append(control_img)
|
||||
all_prompts.append(prompt)
|
||||
|
||||
logged_artifacts = []
|
||||
for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
|
||||
logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
|
||||
logged_artifacts.append(wandb.Image(img, caption=prompt))
|
||||
|
||||
wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
|
||||
wandb_tracker[0].log({"dataset_samples": logged_artifacts})
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
|
||||
@@ -2,6 +2,34 @@
|
||||
This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
|
||||
This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script adds support for training LoRA layers on the UNet model.
|
||||
|
||||
## Running locally with PyTorch
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install .
|
||||
```
|
||||
|
||||
Then cd into the example folder and run:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
|
||||
## Training script example
|
||||
|
||||
```bash
|
||||
@@ -9,7 +37,7 @@ export MODEL_ID="timbrooks/instruct-pix2pix"
|
||||
export DATASET_ID="instruction-tuning-sd/cartoonization"
|
||||
export OUTPUT_DIR="instructPix2Pix-cartoonization"
|
||||
|
||||
accelerate launch finetune_instruct_pix2pix.py \
|
||||
accelerate launch train_instruct_pix2pix_lora.py \
|
||||
--pretrained_model_name_or_path=$MODEL_ID \
|
||||
--dataset_name=$DATASET_ID \
|
||||
--enable_xformers_memory_efficient_attention \
|
||||
@@ -24,7 +52,10 @@ accelerate launch finetune_instruct_pix2pix.py \
|
||||
--rank=4 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--report_to=wandb \
|
||||
--push_to_hub
|
||||
--push_to_hub \
|
||||
--original_image_column="original_image" \
|
||||
--edited_image_column="cartoonized_image" \
|
||||
--edit_prompt_column="edit_prompt"
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
@@ -14,7 +14,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
|
||||
"""
|
||||
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
|
||||
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@@ -30,6 +33,7 @@ import numpy as np
|
||||
import PIL
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
@@ -39,21 +43,28 @@ from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
|
||||
from diffusers.models.lora import LoRALinearLayer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available
|
||||
from diffusers.training_utils import EMAModel, cast_training_params
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -63,6 +74,92 @@ DATASET_NAME_MAPPING = {
|
||||
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images: list = None,
|
||||
base_model: str = None,
|
||||
dataset_name: str = None,
|
||||
repo_folder: str = None,
|
||||
):
|
||||
img_str = ""
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
model_description = f"""
|
||||
# LoRA text2image fine-tuning - {repo_id}
|
||||
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
|
||||
{img_str}
|
||||
"""
|
||||
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=base_model,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = [
|
||||
"stable-diffusion",
|
||||
"stable-diffusion-diffusers",
|
||||
"text-to-image",
|
||||
"instruct-pix2pix",
|
||||
"diffusers",
|
||||
"diffusers-training",
|
||||
"lora",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
def log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
original_image = download_image(args.val_image_url)
|
||||
edited_images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
args.validation_prompt,
|
||||
image=original_image,
|
||||
num_inference_steps=20,
|
||||
image_guidance_scale=1.5,
|
||||
guidance_scale=7,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
||||
for edited_image in edited_images:
|
||||
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
|
||||
tracker.log({"validation": wandb_table})
|
||||
|
||||
return edited_images
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
|
||||
parser.add_argument(
|
||||
@@ -417,11 +514,6 @@ def main():
|
||||
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
import wandb
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -467,49 +559,58 @@ def main():
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
|
||||
)
|
||||
|
||||
# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
|
||||
# it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
|
||||
# then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
|
||||
# from the pre-trained checkpoints. For the extra channels added to the first layer, they are
|
||||
# initialized to zero.
|
||||
logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
|
||||
in_channels = 8
|
||||
out_channels = unet.conv_in.out_channels
|
||||
unet.register_to_config(in_channels=in_channels)
|
||||
|
||||
with torch.no_grad():
|
||||
new_conv_in = nn.Conv2d(
|
||||
in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
|
||||
)
|
||||
new_conv_in.weight.zero_()
|
||||
new_conv_in.weight[:, :in_channels, :, :].copy_(unet.conv_in.weight)
|
||||
unet.conv_in = new_conv_in
|
||||
|
||||
# Freeze vae, text_encoder and unet
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
|
||||
# referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
|
||||
unet_lora_parameters = []
|
||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
||||
# Parse the attention module.
|
||||
attn_module = unet
|
||||
for n in attn_processor_name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
# Freeze the unet parameters before adding adapters
|
||||
unet.requires_grad_(False)
|
||||
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
|
||||
# Accumulate the LoRA params to optimize.
|
||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# Add adapter and make sure the trainable params are in float32.
|
||||
unet.add_adapter(unet_lora_config)
|
||||
if args.mixed_precision == "fp16":
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(unet, dtype=torch.float32)
|
||||
|
||||
# Create EMA for the unet.
|
||||
if args.use_ema:
|
||||
@@ -528,6 +629,13 @@ def main():
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
trainable_params = filter(lambda p: p.requires_grad, unet.parameters())
|
||||
|
||||
def unwrap_model(model):
    """Return the underlying model without accelerator and compile wrappers.

    The model is first passed through ``accelerator.unwrap_model``; if the
    result is a compiled module, its ``_orig_mod`` attribute is returned
    instead.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -540,7 +648,8 @@ def main():
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
if weights:
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
@@ -589,9 +698,9 @@ def main():
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
# train on only unet_lora_parameters
|
||||
# train on only lora_layers
|
||||
optimizer = optimizer_cls(
|
||||
unet_lora_parameters,
|
||||
trainable_params,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
@@ -730,22 +839,27 @@ def main():
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
if args.use_ema:
|
||||
@@ -765,8 +879,14 @@ def main():
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
@@ -885,7 +1005,7 @@ def main():
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# Predict the noise residual and compute loss
|
||||
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
@@ -895,7 +1015,7 @@ def main():
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
|
||||
accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
@@ -903,7 +1023,7 @@ def main():
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
if args.use_ema:
|
||||
ema_unet.step(unet_lora_parameters)
|
||||
ema_unet.step(trainable_params)
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
accelerator.log({"train_loss": train_loss}, step=global_step)
|
||||
@@ -933,6 +1053,16 @@ def main():
|
||||
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
unwrapped_unet = unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(unwrapped_unet)
|
||||
)
|
||||
|
||||
StableDiffusionInstructPix2PixPipeline.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
safe_serialization=True,
|
||||
)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
@@ -959,45 +1089,22 @@ def main():
|
||||
# The models need unwrapping for compatibility in distributed training mode.
|
||||
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
unet=unwrap_model(unet),
|
||||
text_encoder=unwrap_model(text_encoder),
|
||||
vae=unwrap_model(vae),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
original_image = download_image(args.val_image_url)
|
||||
edited_images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
args.validation_prompt,
|
||||
image=original_image,
|
||||
num_inference_steps=20,
|
||||
image_guidance_scale=1.5,
|
||||
guidance_scale=7,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
||||
for edited_image in edited_images:
|
||||
wandb_table.add_data(
|
||||
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
|
||||
)
|
||||
tracker.log({"validation": wandb_table})
|
||||
if args.use_ema:
|
||||
# Switch back to the original UNet parameters.
|
||||
ema_unet.restore(unet.parameters())
|
||||
@@ -1008,22 +1115,47 @@ def main():
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
if args.use_ema:
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
|
||||
# store only LORA layers
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unwrapped_unet = unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
|
||||
StableDiffusionInstructPix2PixPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
unet=unet,
|
||||
text_encoder=unwrap_model(text_encoder),
|
||||
vae=unwrap_model(vae),
|
||||
unet=unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
# store only LORA layers
|
||||
unet.save_attn_procs(args.output_dir)
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
images = None
|
||||
if (args.val_image_url is not None) and (args.validation_prompt is not None):
|
||||
images = log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(
|
||||
repo_id,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
dataset_name=args.dataset_name,
|
||||
repo_folder=args.output_dir,
|
||||
)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
@@ -1031,31 +1163,6 @@ def main():
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
if args.validation_prompt is not None:
|
||||
edited_images = []
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
with torch.autocast(str(accelerator.device).replace(":0", "")):
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
args.validation_prompt,
|
||||
image=original_image,
|
||||
num_inference_steps=20,
|
||||
image_guidance_scale=1.5,
|
||||
guidance_scale=7,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
||||
for edited_image in edited_images:
|
||||
wandb_table.add_data(
|
||||
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
|
||||
)
|
||||
tracker.log({"test": wandb_table})
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
@@ -310,7 +310,7 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline):
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
|
||||
@@ -233,7 +233,7 @@ class PromptDiffusionPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
|
||||
@@ -78,7 +78,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
# Copy from statement here and all the methods we take from stable_diffusion_pipeline
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.retriever = retriever
|
||||
|
||||
|
||||
@@ -6,4 +6,4 @@ torch==2.2.0
|
||||
torchvision>=0.16
|
||||
ftfy==6.1.1
|
||||
tensorboard==2.14.0
|
||||
Jinja2==3.1.4
|
||||
Jinja2==3.1.5
|
||||
|
||||
@@ -765,7 +765,7 @@ def main(args):
|
||||
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import is_accelerate_available
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = [
|
||||
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
|
||||
@@ -89,7 +90,10 @@ def main(args):
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# scheduler
|
||||
flow_shift = 3.0
|
||||
if args.image_size == 4096:
|
||||
flow_shift = 6.0
|
||||
else:
|
||||
flow_shift = 3.0
|
||||
|
||||
# model config
|
||||
if args.model_type == "SanaMS_1600M_P1_D20":
|
||||
@@ -99,7 +103,7 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"{args.model_type} is not supported.")
|
||||
# Positional embedding interpolation scale.
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
@@ -272,9 +276,9 @@ if __name__ == "__main__":
|
||||
"--image_size",
|
||||
default=1024,
|
||||
type=int,
|
||||
choices=[512, 1024, 2048],
|
||||
choices=[512, 1024, 2048, 4096],
|
||||
required=False,
|
||||
help="Image size of pretrained model, 512, 1024 or 2048.",
|
||||
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
|
||||
|
||||
@@ -135,6 +135,7 @@ _deps = [
|
||||
"transformers>=4.41.2",
|
||||
"urllib3<=2.0.0",
|
||||
"black",
|
||||
"phonemizer",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
@@ -227,6 +228,7 @@ extras["test"] = deps_list(
|
||||
"scipy",
|
||||
"torchvision",
|
||||
"transformers",
|
||||
"phonemizer",
|
||||
)
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
|
||||
|
||||
@@ -43,4 +43,5 @@ deps = {
|
||||
"transformers": "transformers>=4.41.2",
|
||||
"urllib3": "urllib3<=2.0.0",
|
||||
"black": "black",
|
||||
"phonemizer": "phonemizer",
|
||||
}
|
||||
|
||||
@@ -28,13 +28,20 @@ from ..models.modeling_utils import ModelMixin, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
@@ -43,6 +50,8 @@ from ..utils import (
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_peft_available():
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
@@ -297,6 +306,152 @@ def _best_guess_weight_name(
|
||||
return weight_name
|
||||
|
||||
|
||||
def _load_lora_into_text_encoder(
    state_dict,
    network_alphas,
    text_encoder,
    prefix=None,
    lora_scale=1.0,
    text_encoder_name="text_encoder",
    adapter_name=None,
    _pipeline=None,
    low_cpu_mem_usage=False,
):
    """Load the LoRA layers found in `state_dict` into `text_encoder` as a PEFT adapter.

    Args:
        state_dict: Mapping from parameter names to LoRA weights. Only keys whose first
            dot-separated component equals `prefix` are loaded.
        network_alphas: Optional mapping of alpha values. When not `None`, it is filtered by
            `prefix` the same way as `state_dict` and forwarded to `get_peft_kwargs`.
        text_encoder: Model that receives the adapter. This function calls its
            `load_adapter` and `to` methods and reads its `device` and `dtype` attributes.
        prefix: Key prefix selecting the text-encoder entries. Falls back to
            `text_encoder_name` when `None`.
        lora_scale: Weight passed to `scale_lora_layers` after the adapter is loaded.
        text_encoder_name: Substring searched for in the keys to decide whether there is
            anything to load; also the default `prefix`.
        adapter_name: Name of the adapter. When `None`, it is obtained from
            `get_adapter_name(text_encoder)`.
        _pipeline: Optional pipeline passed to `_func_optionally_disable_offloading`; if
            offloading was detected, the matching `enable_*_cpu_offload()` method is called
            on it again once loading is done.
        low_cpu_mem_usage: When truthy, forwarded to `text_encoder.load_adapter`. Requires
            `peft>=0.13.1` and `transformers>4.45.2`.

    Raises:
        ValueError: If the PEFT backend is not enabled, if `low_cpu_mem_usage` is requested
            with an unsupported `peft`/`transformers` version, or if the derived LoRA config
            enables DoRA with `peft<0.9.0` or `lora_bias` with `peft<=0.13.2`.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    # Extra keyword arguments for `load_adapter`; only populated when low-memory loading is requested.
    peft_kwargs = {}
    if low_cpu_mem_usage:
        if not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )
        if not is_transformers_version(">", "4.45.2"):
            # Note from sayakpaul: It's not in `transformers` stable yet.
            # https://github.com/huggingface/transformers/pull/33725/
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
            )
        peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

    # Imported here, after the PEFT backend check above.
    from peft import LoraConfig

    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
    # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
    # their prefixes.
    keys = list(state_dict.keys())
    prefix = text_encoder_name if prefix is None else prefix

    # Safe prefix to check with.
    # Note that this is a substring test on `text_encoder_name`, not on `prefix`.
    if any(text_encoder_name in key for key in keys):
        # Load the layers corresponding to text encoder and make necessary adjustments.
        # The `split(".")[0] == prefix` test keeps e.g. `text_encoder_2.*` keys out when
        # the prefix is `text_encoder`.
        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
        # NOTE(review): `str.replace` removes every occurrence of "{prefix}.", not only the
        # leading one -- confirm that no key repeats the prefix further along its path.
        text_encoder_lora_state_dict = {
            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
        }

        if len(text_encoder_lora_state_dict) > 0:
            logger.info(f"Loading {prefix}.")
            # Maps each `lora_B` weight key present in the state dict to its rank.
            rank = {}
            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

            # convert state dict
            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

            # The rank is read from the second dimension of each `lora_B` weight; modules
            # without a LoRA entry in the state dict are skipped.
            for name, _ in text_encoder_attn_modules(text_encoder):
                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
                    rank_key = f"{name}.{module}.lora_B.weight"
                    if rank_key not in text_encoder_lora_state_dict:
                        continue
                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

            for name, _ in text_encoder_mlp_modules(text_encoder):
                for module in ("fc1", "fc2"):
                    rank_key = f"{name}.{module}.lora_B.weight"
                    if rank_key not in text_encoder_lora_state_dict:
                        continue
                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

            # Apply the same prefix filtering and stripping to the alphas as to the weights.
            if network_alphas is not None:
                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)

            # DoRA: fail if it is requested on an old `peft`; if it is not requested, drop
            # the key on old `peft` versions before building the config.
            if "use_dora" in lora_config_kwargs:
                if lora_config_kwargs["use_dora"]:
                    if is_peft_version("<", "0.9.0"):
                        raise ValueError(
                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                        )
                else:
                    if is_peft_version("<", "0.9.0"):
                        lora_config_kwargs.pop("use_dora")

            # LoRA bias: same pattern as for DoRA, with `peft` 0.14.0 as the threshold.
            if "lora_bias" in lora_config_kwargs:
                if lora_config_kwargs["lora_bias"]:
                    if is_peft_version("<=", "0.13.2"):
                        raise ValueError(
                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
                        )
                else:
                    if is_peft_version("<=", "0.13.2"):
                        lora_config_kwargs.pop("lora_bias")

            lora_config = LoraConfig(**lora_config_kwargs)

            # adapter_name
            if adapter_name is None:
                adapter_name = get_adapter_name(text_encoder)

            # Remove any offloading hooks before the adapter is injected; the returned flags
            # record which kind of offloading has to be re-enabled afterwards.
            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)

            # inject LoRA layers and load the state dict
            # in transformers we automatically check whether the adapter name is already in use or not
            text_encoder.load_adapter(
                adapter_name=adapter_name,
                adapter_state_dict=text_encoder_lora_state_dict,
                peft_config=lora_config,
                **peft_kwargs,
            )

            # scale LoRA layers with `lora_scale`
            scale_lora_layers(text_encoder, weight=lora_scale)

            # Move the whole model, including the new LoRA layers, to the encoder's own
            # device and dtype.
            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

            # Offload back.
            if is_model_cpu_offload:
                _pipeline.enable_model_cpu_offload()
            elif is_sequential_cpu_offload:
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
"""Utility class for handling LoRAs."""
|
||||
|
||||
@@ -327,27 +482,7 @@ class LoraBaseMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(cls, *args, **kwargs):
|
||||
|
||||
@@ -973,3 +973,178 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
|
||||
|
||||
def remap_norm_scale_shift_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
||||
|
||||
def remap_txt_in_(key, state_dict):
|
||||
def rename_key(key):
|
||||
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
||||
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
||||
new_key = new_key.replace("txt_in", "context_embedder")
|
||||
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
||||
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
||||
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
||||
new_key = new_key.replace("mlp", "ff")
|
||||
return new_key
|
||||
|
||||
if "self_attn_qkv" in key:
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
||||
else:
|
||||
state_dict[rename_key(key)] = state_dict.pop(key)
|
||||
|
||||
def remap_img_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
|
||||
else:
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
||||
|
||||
def remap_txt_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
|
||||
else:
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
||||
|
||||
def remap_single_transformer_blocks_(key, state_dict, hidden_size=3072):
    """Rewrite one ``single_blocks`` LoRA entry to the ``single_transformer_blocks`` layout, in place.

    The entry under ``key`` is always removed from ``state_dict``. What replaces
    it depends on the key:

    * Keys containing ``linear1.lora_A.{weight,bias}`` or
      ``linear1.lora_B.{weight,bias}`` are expanded into four entries named
      ``<prefix>.attn.to_q``, ``<prefix>.attn.to_k``, ``<prefix>.attn.to_v`` and
      ``<prefix>.proj_mlp`` (each followed by the original
      ``.lora_X.weight``/``.lora_X.bias`` tail), where ``<prefix>`` is the key
      with ``single_blocks`` renamed and the ``.linear1.lora_X.*`` suffix removed.
      ``lora_A`` tensors are stored unchanged under all four names (same object);
      ``lora_B`` tensors are split along dim 0 into three pieces of
      ``hidden_size`` rows followed by one piece holding the remaining rows.
    * Any other key is only renamed: ``single_blocks`` ->
      ``single_transformer_blocks``, ``linear2`` -> ``proj_out``, ``q_norm`` ->
      ``attn.norm_q``, ``k_norm`` -> ``attn.norm_k``.

    Args:
        key: Key of the entry to rewrite; must be present in ``state_dict``.
        state_dict: Mapping of parameter names to tensors. Mutated in place.
        hidden_size: Number of rows assigned to each of q, k and v when a
            ``lora_B`` tensor is split. Defaults to 3072, the value that was
            previously hard-coded.

    Raises:
        KeyError: If ``key`` is not in ``state_dict``.
        ValueError: If a ``lora_B`` tensor has fewer than ``3 * hidden_size`` rows.
    """
    renamed = key.replace("single_blocks", "single_transformer_blocks")
    targets = ("attn.to_q", "attn.to_k", "attn.to_v", "proj_mlp")

    # The weight and bias cases differ only in the trailing parameter name.
    for param in ("weight", "bias"):
        if f"linear1.lora_A.{param}" not in key and f"linear1.lora_B.{param}" not in key:
            continue

        tensor = state_dict.pop(key)
        if "lora_A" in key:
            adapter = "lora_A"
            parts = (tensor,) * len(targets)
        else:
            adapter = "lora_B"
            # Whatever is left after q, k and v goes to the MLP projection.
            mlp_rows = tensor.size(0) - 3 * hidden_size
            if mlp_rows < 0:
                raise ValueError(
                    f"Cannot split `{key}`: expected at least {3 * hidden_size} rows, got {tensor.size(0)}."
                )
            parts = torch.split(tensor, (hidden_size, hidden_size, hidden_size, mlp_rows), dim=0)

        prefix = renamed.removesuffix(f".linear1.{adapter}.{param}")
        for target, part in zip(targets, parts):
            state_dict[f"{prefix}.{target}.{adapter}.{param}"] = part
        return

    # Not a fused `linear1` entry: plain substring renames only.
    renamed = renamed.replace("linear2", "proj_out")
    renamed = renamed.replace("q_norm", "attn.norm_q")
    renamed = renamed.replace("k_norm", "attn.norm_k")
    state_dict[renamed] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_in": "x_embedder",
|
||||
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
||||
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
||||
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
||||
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
||||
"double_blocks": "transformer_blocks",
|
||||
"img_attn_q_norm": "attn.norm_q",
|
||||
"img_attn_k_norm": "attn.norm_k",
|
||||
"img_attn_proj": "attn.to_out.0",
|
||||
"txt_attn_q_norm": "attn.norm_added_q",
|
||||
"txt_attn_k_norm": "attn.norm_added_k",
|
||||
"txt_attn_proj": "attn.to_add_out",
|
||||
"img_mod.linear": "norm1.linear",
|
||||
"img_norm1": "norm1.norm",
|
||||
"img_norm2": "norm2",
|
||||
"img_mlp": "ff",
|
||||
"txt_mod.linear": "norm1_context.linear",
|
||||
"txt_norm1": "norm1.norm",
|
||||
"txt_norm2": "norm2_context",
|
||||
"txt_mlp": "ff_context",
|
||||
"self_attn_proj": "attn.to_out.0",
|
||||
"modulation.linear": "norm.linear",
|
||||
"pre_norm": "norm.norm",
|
||||
"final_layer.norm_final": "norm_out.norm",
|
||||
"final_layer.linear": "proj_out",
|
||||
"fc1": "net.0.proj",
|
||||
"fc2": "net.2",
|
||||
"input_embedder": "proj_in",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"txt_in": remap_txt_in_,
|
||||
"img_attn_qkv": remap_img_attn_qkv_,
|
||||
"txt_attn_qkv": remap_txt_attn_qkv_,
|
||||
"single_blocks": remap_single_transformer_blocks_,
|
||||
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
||||
}
|
||||
|
||||
# Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
|
||||
# and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
|
||||
# sure that both follow the same initial format by stripping off the "transformer." prefix.
|
||||
for key in list(converted_state_dict.keys()):
|
||||
if key.startswith("transformer."):
|
||||
converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
|
||||
if key.startswith("diffusion_model."):
|
||||
converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
|
||||
|
||||
# Rename and remap the state dict keys
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
# Add back the "transformer." prefix
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -20,22 +20,24 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
)
|
||||
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
|
||||
from .lora_base import ( # noqa
|
||||
LORA_WEIGHT_NAME,
|
||||
LORA_WEIGHT_NAME_SAFE,
|
||||
LoraBaseMixin,
|
||||
_fetch_state_dict,
|
||||
_load_lora_into_text_encoder,
|
||||
)
|
||||
from .lora_conversion_utils import (
|
||||
_convert_bfl_flux_control_lora_to_diffusers,
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
@@ -54,9 +56,6 @@ if is_torch_version(">=", "1.9.0"):
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
TEXT_ENCODER_NAME = "text_encoder"
|
||||
@@ -348,119 +347,17 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = cls.text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(cls.text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
||||
]
|
||||
network_alphas = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
network_alphas=network_alphas,
|
||||
lora_scale=lora_scale,
|
||||
text_encoder=text_encoder,
|
||||
prefix=prefix,
|
||||
text_encoder_name=cls.text_encoder_name,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
@@ -891,119 +788,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = cls.text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(cls.text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
||||
]
|
||||
network_alphas = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
network_alphas=network_alphas,
|
||||
lora_scale=lora_scale,
|
||||
text_encoder=text_encoder,
|
||||
prefix=prefix,
|
||||
text_encoder_name=cls.text_encoder_name,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
@@ -1400,119 +1195,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = cls.text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(cls.text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
||||
]
|
||||
network_alphas = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
network_alphas=network_alphas,
|
||||
lora_scale=lora_scale,
|
||||
text_encoder=text_encoder,
|
||||
prefix=prefix,
|
||||
text_encoder_name=cls.text_encoder_name,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
@@ -2032,119 +1725,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = cls.text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(cls.text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
||||
]
|
||||
network_alphas = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
network_alphas=network_alphas,
|
||||
lora_scale=lora_scale,
|
||||
text_encoder=text_encoder,
|
||||
prefix=prefix,
|
||||
text_encoder_name=cls.text_encoder_name,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
|
||||
@@ -2203,7 +1794,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder"],
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
@@ -2597,119 +2188,17 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = cls.text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(cls.text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
||||
]
|
||||
network_alphas = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
_load_lora_into_text_encoder(
|
||||
state_dict=state_dict,
|
||||
network_alphas=network_alphas,
|
||||
lora_scale=lora_scale,
|
||||
text_encoder=text_encoder,
|
||||
prefix=prefix,
|
||||
text_encoder_name=cls.text_encoder_name,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
@@ -3007,10 +2496,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder"],
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
@@ -3051,8 +2539,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
@@ -3066,9 +2553,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
@@ -3315,10 +2799,9 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder"],
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
@@ -3359,8 +2842,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
@@ -3374,9 +2856,6 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
@@ -3623,10 +3102,9 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder"],
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
@@ -3667,8 +3145,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
@@ -3682,9 +3159,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
@@ -3931,10 +3405,9 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder"],
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
@@ -3975,8 +3448,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
@@ -3990,9 +3462,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
@@ -4007,7 +3476,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
@@ -4018,7 +3486,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
We support loading original format HunyuanVideo LoRA checkpoints.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
@@ -4101,6 +3569,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
|
||||
if is_original_hunyuan_video:
|
||||
state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
@@ -4239,10 +3711,9 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder"],
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
@@ -4283,8 +3754,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
@@ -4298,9 +3768,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import (
|
||||
MIN_PEFT_VERSION,
|
||||
@@ -30,20 +29,16 @@ from ..utils import (
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
logging,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .lora_base import _fetch_state_dict
|
||||
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
|
||||
from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
@@ -140,27 +135,7 @@ class PeftAdapterMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
|
||||
r"""
|
||||
@@ -325,15 +300,17 @@ class PeftAdapterMixin:
|
||||
try:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
except RuntimeError as e:
|
||||
for module in self.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
for active_adapter in active_adapters:
|
||||
if adapter_name in active_adapter:
|
||||
module.delete_adapter(adapter_name)
|
||||
except Exception as e:
|
||||
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
|
||||
if hasattr(self, "peft_config"):
|
||||
for module in self.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
for active_adapter in active_adapters:
|
||||
if adapter_name in active_adapter:
|
||||
module.delete_adapter(adapter_name)
|
||||
|
||||
self.peft_config.pop(adapter_name)
|
||||
self.peft_config.pop(adapter_name)
|
||||
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ def load_single_file_sub_model(
|
||||
local_files_only=False,
|
||||
torch_dtype=None,
|
||||
is_legacy_loading=False,
|
||||
disable_mmap=False,
|
||||
**kwargs,
|
||||
):
|
||||
if is_pipeline_module:
|
||||
@@ -106,6 +107,7 @@ def load_single_file_sub_model(
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -308,6 +310,9 @@ class FromSingleFileMixin:
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
|
||||
component configs in Diffusers format.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
@@ -355,6 +360,7 @@ class FromSingleFileMixin:
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
@@ -383,6 +389,7 @@ class FromSingleFileMixin:
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
@@ -504,6 +511,7 @@ class FromSingleFileMixin:
|
||||
original_config=original_config,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
except SingleFileComponentError as e:
|
||||
|
||||
@@ -25,6 +25,7 @@ from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
@@ -106,6 +107,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AuraFlowTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -182,6 +187,9 @@ class FromOriginalModelMixin:
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
@@ -229,6 +237,7 @@ class FromOriginalModelMixin:
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
@@ -241,6 +250,7 @@ class FromOriginalModelMixin:
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
|
||||
@@ -94,6 +94,12 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
|
||||
"animatediff_rgb": "controlnet_cond_embedding.weight",
|
||||
"auraflow": [
|
||||
"double_layers.0.attn.w2q.weight",
|
||||
"double_layers.0.attn.w1q.weight",
|
||||
"cond_seq_linear.weight",
|
||||
"t_embedder.mlp.0.weight",
|
||||
],
|
||||
"flux": [
|
||||
"double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
@@ -154,6 +160,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
|
||||
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
|
||||
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
|
||||
"auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
|
||||
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
@@ -179,6 +186,7 @@ DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
|
||||
"inpainting": 512,
|
||||
"inpainting_v2": 512,
|
||||
"controlnet": 512,
|
||||
"instruct-pix2pix": 512,
|
||||
"v2": 768,
|
||||
"v1": 512,
|
||||
}
|
||||
@@ -380,6 +388,7 @@ def load_single_file_checkpoint(
|
||||
cache_dir=None,
|
||||
local_files_only=None,
|
||||
revision=None,
|
||||
disable_mmap=False,
|
||||
):
|
||||
if os.path.isfile(pretrained_model_link_or_path):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path
|
||||
@@ -397,7 +406,7 @@ def load_single_file_checkpoint(
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
checkpoint = load_state_dict(pretrained_model_link_or_path)
|
||||
checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
|
||||
|
||||
# some checkpoints contain the model state dict under a "state_dict" key
|
||||
while "state_dict" in checkpoint:
|
||||
@@ -597,10 +606,14 @@ def infer_diffusers_model_type(checkpoint):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
):
|
||||
if checkpoint["img_in.weight"].shape[1] == 384:
|
||||
model_type = "flux-fill"
|
||||
if "model.diffusion_model.img_in.weight" in checkpoint:
|
||||
key = "model.diffusion_model.img_in.weight"
|
||||
else:
|
||||
key = "img_in.weight"
|
||||
|
||||
elif checkpoint["img_in.weight"].shape[1] == 128:
|
||||
if checkpoint[key].shape[1] == 384:
|
||||
model_type = "flux-fill"
|
||||
elif checkpoint[key].shape[1] == 128:
|
||||
model_type = "flux-depth"
|
||||
else:
|
||||
model_type = "flux-dev"
|
||||
@@ -635,6 +648,9 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
|
||||
model_type = "hunyuan-video"
|
||||
|
||||
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
|
||||
model_type = "auraflow"
|
||||
|
||||
elif (
|
||||
CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
|
||||
and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
|
||||
@@ -2090,6 +2106,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
@@ -2689,3 +2706,95 @@ def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
handler_fn_inplace(key, checkpoint)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
state_dict_keys = list(checkpoint.keys())
|
||||
|
||||
# Handle register tokens and positional embeddings
|
||||
converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
|
||||
|
||||
# Handle time step projection
|
||||
converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
|
||||
converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
|
||||
converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
|
||||
converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
|
||||
|
||||
# Handle context embedder
|
||||
converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
|
||||
|
||||
# Calculate the number of layers
|
||||
def calculate_layers(keys, key_prefix):
|
||||
layers = set()
|
||||
for k in keys:
|
||||
if key_prefix in k:
|
||||
layer_num = int(k.split(".")[1]) # get the layer number
|
||||
layers.add(layer_num)
|
||||
return len(layers)
|
||||
|
||||
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
|
||||
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
|
||||
|
||||
# MMDiT blocks
|
||||
for i in range(mmdit_layers):
|
||||
# Feed-forward
|
||||
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
|
||||
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
||||
for orig_k, diffuser_k in path_mapping.items():
|
||||
for k, v in weight_mapping.items():
|
||||
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
|
||||
f"double_layers.{i}.{orig_k}.{k}.weight", None
|
||||
)
|
||||
|
||||
# Norms
|
||||
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
|
||||
for orig_k, diffuser_k in path_mapping.items():
|
||||
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
|
||||
f"double_layers.{i}.{orig_k}.1.weight", None
|
||||
)
|
||||
|
||||
# Attentions
|
||||
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
|
||||
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
|
||||
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
|
||||
for k, v in attn_mapping.items():
|
||||
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
|
||||
f"double_layers.{i}.attn.{k}.weight", None
|
||||
)
|
||||
|
||||
# Single-DiT blocks
|
||||
for i in range(single_dit_layers):
|
||||
# Feed-forward
|
||||
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
||||
for k, v in mapping.items():
|
||||
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
|
||||
f"single_layers.{i}.mlp.{k}.weight", None
|
||||
)
|
||||
|
||||
# Norms
|
||||
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
|
||||
f"single_layers.{i}.modCX.1.weight", None
|
||||
)
|
||||
|
||||
# Attentions
|
||||
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
|
||||
for k, v in x_attn_mapping.items():
|
||||
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
|
||||
f"single_layers.{i}.attn.{k}.weight", None
|
||||
)
|
||||
# Final blocks
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
|
||||
|
||||
# Handle the final norm layer
|
||||
norm_weight = checkpoint.pop("modF.1.weight", None)
|
||||
if norm_weight is not None:
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
|
||||
else:
|
||||
converted_state_dict["norm_out.linear.weight"] = None
|
||||
|
||||
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
|
||||
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
|
||||
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -21,7 +21,6 @@ import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import nn
|
||||
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
@@ -44,13 +43,11 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from .lora_base import _func_optionally_disable_offloading
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -411,27 +408,7 @@ class UNet2DConditionLoadersMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def save_attn_procs(
|
||||
self,
|
||||
|
||||
@@ -486,6 +486,9 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_sample_stride_height = 448
|
||||
self.tile_sample_stride_width = 448
|
||||
|
||||
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
@@ -515,6 +518,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
@@ -606,11 +611,106 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
|
||||
raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
latent_height = height // self.spatial_compression_ratio
|
||||
latent_width = width // self.spatial_compression_ratio
|
||||
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
||||
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
blend_height = tile_latent_min_height - tile_latent_stride_height
|
||||
blend_width = tile_latent_min_width - tile_latent_stride_width
|
||||
|
||||
# Split x into overlapping tiles and encode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], self.tile_sample_stride_height):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], self.tile_sample_stride_width):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
|
||||
if (
|
||||
tile.shape[2] % self.spatial_compression_ratio != 0
|
||||
or tile.shape[3] % self.spatial_compression_ratio != 0
|
||||
):
|
||||
pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
|
||||
pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
|
||||
tile = F.pad(tile, (0, pad_w, 0, pad_h))
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
|
||||
|
||||
if not return_dict:
|
||||
return (encoded,)
|
||||
return EncoderOutput(latent=encoded)
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
|
||||
batch_size, num_channels, height, width = z.shape
|
||||
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
||||
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
|
||||
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
|
||||
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, height, tile_latent_stride_height):
|
||||
row = []
|
||||
for j in range(0, width, tile_latent_stride_width):
|
||||
tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
decoded = torch.cat(result_rows, dim=2)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
|
||||
encoded = self.encode(sample, return_dict=False)[0]
|
||||
|
||||
@@ -1010,10 +1010,12 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# The minimal tile height and width for spatial tiling to be used
|
||||
self.tile_sample_min_height = 512
|
||||
self.tile_sample_min_width = 512
|
||||
self.tile_sample_min_num_frames = 16
|
||||
|
||||
# The minimal distance between two spatial tiles
|
||||
self.tile_sample_stride_height = 448
|
||||
self.tile_sample_stride_width = 448
|
||||
self.tile_sample_stride_num_frames = 8
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
|
||||
@@ -1023,8 +1025,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_sample_min_num_frames: Optional[int] = None,
|
||||
tile_sample_stride_height: Optional[float] = None,
|
||||
tile_sample_stride_width: Optional[float] = None,
|
||||
tile_sample_stride_num_frames: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
@@ -1046,8 +1050,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
|
||||
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
@@ -1073,18 +1079,13 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = x.shape
|
||||
|
||||
if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
|
||||
return self._temporal_tiled_encode(x)
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
if self.use_framewise_encoding:
|
||||
# TODO(aryan): requires investigation
|
||||
raise NotImplementedError(
|
||||
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
|
||||
"quality issues caused by splitting inference across frame dimension. If you believe this "
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
enc = self.encoder(x)
|
||||
enc = self.encoder(x)
|
||||
|
||||
return enc
|
||||
|
||||
@@ -1121,19 +1122,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
|
||||
|
||||
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
|
||||
return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
|
||||
|
||||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||
return self.tiled_decode(z, temb, return_dict=return_dict)
|
||||
|
||||
if self.use_framewise_decoding:
|
||||
# TODO(aryan): requires investigation
|
||||
raise NotImplementedError(
|
||||
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
|
||||
"quality issues caused by splitting inference across frame dimension. If you believe this "
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
dec = self.decoder(z, temb)
|
||||
dec = self.decoder(z, temb)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
@@ -1189,6 +1186,14 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
@@ -1217,17 +1222,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for i in range(0, height, self.tile_sample_stride_height):
|
||||
row = []
|
||||
for j in range(0, width, self.tile_sample_stride_width):
|
||||
if self.use_framewise_encoding:
|
||||
# TODO(aryan): requires investigation
|
||||
raise NotImplementedError(
|
||||
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
|
||||
"quality issues caused by splitting inference across frame dimension. If you believe this "
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
time = self.encoder(
|
||||
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
|
||||
)
|
||||
time = self.encoder(
|
||||
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
|
||||
)
|
||||
|
||||
row.append(time)
|
||||
rows.append(row)
|
||||
@@ -1283,17 +1280,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for i in range(0, height, tile_latent_stride_height):
|
||||
row = []
|
||||
for j in range(0, width, tile_latent_stride_width):
|
||||
if self.use_framewise_decoding:
|
||||
# TODO(aryan): requires investigation
|
||||
raise NotImplementedError(
|
||||
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
|
||||
"quality issues caused by splitting inference across frame dimension. If you believe this "
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
time = self.decoder(
|
||||
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
|
||||
)
|
||||
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
|
||||
|
||||
row.append(time)
|
||||
rows.append(row)
|
||||
@@ -1318,6 +1305,74 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
|
||||
batch_size, num_channels, num_frames, height, width = x.shape
|
||||
latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
|
||||
|
||||
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
|
||||
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
|
||||
blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
|
||||
|
||||
row = []
|
||||
for i in range(0, num_frames, self.tile_sample_stride_num_frames):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
|
||||
if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
|
||||
tile = self.tiled_encode(tile)
|
||||
else:
|
||||
tile = self.encoder(tile)
|
||||
if i > 0:
|
||||
tile = tile[:, :, 1:, :, :]
|
||||
row.append(tile)
|
||||
|
||||
result_row = []
|
||||
for i, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
|
||||
result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
|
||||
else:
|
||||
result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
|
||||
|
||||
enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
|
||||
return enc
|
||||
|
||||
def _temporal_tiled_decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
|
||||
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
|
||||
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
|
||||
blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
|
||||
|
||||
row = []
|
||||
for i in range(0, num_frames, tile_latent_stride_num_frames):
|
||||
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
|
||||
if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
|
||||
decoded = self.tiled_decode(tile, temb, return_dict=True).sample
|
||||
else:
|
||||
decoded = self.decoder(tile, temb)
|
||||
if i > 0:
|
||||
decoded = decoded[:, :, :-1, :, :]
|
||||
row.append(decoded)
|
||||
|
||||
result_row = []
|
||||
for i, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
|
||||
tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
|
||||
result_row.append(tile)
|
||||
else:
|
||||
result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
|
||||
|
||||
dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
@@ -1334,5 +1389,5 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, temb)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return (dec.sample,)
|
||||
return dec
|
||||
|
||||
@@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class):
|
||||
return old_class
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
||||
def load_state_dict(
|
||||
checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
|
||||
):
|
||||
"""
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
@@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
try:
|
||||
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
||||
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
if disable_mmap:
|
||||
return safetensors.torch.load(open(checkpoint_file, "rb").read())
|
||||
else:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
elif file_extension == GGUF_FILE_EXTENSION:
|
||||
return load_gguf_checkpoint(checkpoint_file)
|
||||
else:
|
||||
|
||||
@@ -559,6 +559,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
||||
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
||||
weights. If set to `False`, `safetensors` weights are not loaded.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -604,6 +607,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
@@ -883,7 +887,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
|
||||
else:
|
||||
param_device = torch.device(torch.cuda.current_device())
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
# move the params from meta device to cpu
|
||||
@@ -920,14 +924,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else: # else let accelerate handle loading and dispatching.
|
||||
# Load weights and dispatch according to the device_map
|
||||
# by default the device_map is None and the weights are loaded on the CPU
|
||||
force_hook = True
|
||||
device_map = _determine_device_map(
|
||||
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
|
||||
)
|
||||
if device_map is None and is_sharded:
|
||||
# we load the parameters on the cpu
|
||||
device_map = {"": "cpu"}
|
||||
force_hook = False
|
||||
try:
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
@@ -937,7 +939,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
force_hooks=force_hook,
|
||||
strict=True,
|
||||
)
|
||||
except AttributeError as e:
|
||||
@@ -967,7 +968,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
force_hooks=force_hook,
|
||||
strict=True,
|
||||
)
|
||||
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
||||
@@ -983,7 +983,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
@@ -1214,7 +1214,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# Adapted from `transformers` modeling_utils.py
|
||||
def _get_no_split_modules(self, device_map: str):
|
||||
"""
|
||||
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
|
||||
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
|
||||
get the underlying `_no_split_modules`.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_processor import (
|
||||
@@ -253,7 +254,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
|
||||
|
||||
|
||||
@@ -120,8 +120,10 @@ class CogVideoXBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
@@ -133,6 +135,7 @@ class CogVideoXBlock(nn.Module):
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
@@ -210,6 +213,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -497,6 +501,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
image_rotary_emb,
|
||||
attention_kwargs,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -505,6 +510,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
)
|
||||
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
|
||||
@@ -221,6 +221,8 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
Scaling factor to apply in 3D positional embeddings across time dimension.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -542,6 +542,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"HunyuanVideoTransformerBlock",
|
||||
"HunyuanVideoSingleTransformerBlock",
|
||||
"HunyuanVideoPatchEmbed",
|
||||
"HunyuanVideoTokenRefiner",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -713,15 +719,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
condition_sequence_length = encoder_hidden_states.shape[1]
|
||||
sequence_length = latent_sequence_length + condition_sequence_length
|
||||
attention_mask = torch.zeros(
|
||||
batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
|
||||
) # [B, N, N]
|
||||
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
|
||||
) # [B, N]
|
||||
|
||||
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
|
||||
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
|
||||
|
||||
for i in range(batch_size):
|
||||
attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
|
||||
attention_mask = attention_mask.unsqueeze(1) # [B, 1, N, N], for broadcasting across attention heads
|
||||
attention_mask[i, : effective_sequence_length[i]] = True
|
||||
attention_mask = attention_mask.unsqueeze(1) # [B, 1, N], for broadcasting across attention heads
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
|
||||
Tuple of downsample block types.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
||||
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
||||
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
|
||||
@@ -103,6 +103,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2D",
|
||||
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
|
||||
layers_per_block: int = 2,
|
||||
@@ -194,19 +195,22 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=attn_norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
if mid_block_type is None:
|
||||
self.mid_block = None
|
||||
else:
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=attn_norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
@@ -322,7 +326,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb)
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
|
||||
@@ -33,6 +33,7 @@ from ...utils import (
|
||||
deprecate,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
@@ -41,6 +42,14 @@ from ...video_processor import VideoProcessor
|
||||
from .pipeline_output import AllegroPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_bs4_available():
|
||||
@@ -194,10 +203,10 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
self.vae_scale_factor_spatial = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
)
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
||||
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
@@ -921,6 +930,9 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.decode_latents(latents)
|
||||
|
||||
@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import replace_example_docstring
|
||||
from ...utils import is_torch_xla_available, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
@@ -66,7 +74,9 @@ class AmusedPipeline(DiffusionPipeline):
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -297,6 +307,9 @@ class AmusedPipeline(DiffusionPipeline):
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, timestep, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
output = latents
|
||||
else:
|
||||
|
||||
@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import replace_example_docstring
|
||||
from ...utils import is_torch_xla_available, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
@@ -81,7 +89,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -323,6 +333,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, timestep, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
output = latents
|
||||
else:
|
||||
|
||||
@@ -21,10 +21,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import replace_example_docstring
|
||||
from ...utils import is_torch_xla_available, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
@@ -89,7 +97,9 @@ class AmusedInpaintPipeline(DiffusionPipeline):
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor,
|
||||
@@ -354,6 +364,9 @@ class AmusedInpaintPipeline(DiffusionPipeline):
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, timestep, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
output = latents
|
||||
else:
|
||||
|
||||
@@ -34,6 +34,7 @@ from ...schedulers import (
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
@@ -47,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
@@ -139,7 +148,7 @@ class AnimateDiffPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
@@ -844,6 +853,9 @@ class AnimateDiffPipeline(
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
# 9. Post processing
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
|
||||
@@ -32,7 +32,7 @@ from ...models import (
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
@@ -41,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
@@ -180,7 +188,7 @@ class AnimateDiffControlNetPipeline(
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.control_video_processor = VideoProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
@@ -1090,6 +1098,9 @@ class AnimateDiffControlNetPipeline(
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
# 9. Post processing
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user