Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a12d8d90e2 | |||
| 5a2909734d | |||
| 5ce8e040aa |
+1
-1
@@ -175,4 +175,4 @@ tags
|
||||
.ruff_cache
|
||||
|
||||
# wandb
|
||||
wandb
|
||||
wandb
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
import torch
|
||||
from fa3_processor import FA3AttnProcessor
|
||||
from diffusers import DiffusionPipeline
|
||||
import argparse
|
||||
import torch.utils.benchmark as benchmark
|
||||
import gc
|
||||
import json
|
||||
|
||||
def flush():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
def bytes_to_giga_bytes(bytes):
|
||||
return f"{(bytes / 1024 / 1024 / 1024):.3f}"
|
||||
|
||||
def benchmark_fn(f, *args, **kwargs):
|
||||
t0 = benchmark.Timer(
|
||||
stmt="f(*args, **kwargs)",
|
||||
globals={"args": args, "kwargs": kwargs, "f": f},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
return f"{(t0.blocked_autorange().mean):.3f}"
|
||||
|
||||
def load_pipeline(args):
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
if args.fa3:
|
||||
pipeline.transformer.set_attn_processor(FA3AttnProcessor())
|
||||
pipeline.vae.set_attn_processor(FA3AttnProcessor())
|
||||
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
return pipeline
|
||||
|
||||
def run_pipeline(pipeline, args):
|
||||
_ = pipeline(
|
||||
prompt="a cat with tiger-like looks",
|
||||
num_images_per_prompt=args.batch_size,
|
||||
guidance_scale=7.5
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fa3", default=0, type=int)
|
||||
parser.add_argument("--batch_size", default=1, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
flush()
|
||||
|
||||
pipeline = load_pipeline(args)
|
||||
|
||||
for _ in range(3):
|
||||
run_pipeline(pipeline, args)
|
||||
|
||||
time = benchmark_fn(run_pipeline, pipeline, args)
|
||||
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
|
||||
data_dict = dict(time=time, memory=memory)
|
||||
print(f"FA3: {bool(args.fa3)} Time: {time} seconds Memory: {memory} GB")
|
||||
|
||||
filename_prefix = f"fa3@{args.fa3}-bs@{args.batch_size}"
|
||||
with open(f"{filename_prefix}.json", "w") as f:
|
||||
json.dump(data_dict, f)
|
||||
|
||||
image = pipeline(
|
||||
prompt="a cat with tiger-like looks",
|
||||
num_images_per_prompt=args.batch_size,
|
||||
num_inference_steps=25,
|
||||
guidance_scale=7.5
|
||||
).images[0]
|
||||
image.save(f"{filename_prefix}.png")
|
||||
@@ -249,12 +249,6 @@
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
title: LuminaNextDiT2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
@@ -282,8 +276,6 @@
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/aura_flow
|
||||
title: AuraFlow
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/blip_diffusion
|
||||
@@ -326,16 +318,12 @@
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/kolors
|
||||
title: Kolors
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
title: Latent Consistency Models
|
||||
- local: api/pipelines/latent_diffusion
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
title: Marigold
|
||||
- local: api/pipelines/panorama
|
||||
@@ -447,8 +435,6 @@
|
||||
title: EulerDiscreteScheduler
|
||||
- local: api/schedulers/flow_match_euler_discrete
|
||||
title: FlowMatchEulerDiscreteScheduler
|
||||
- local: api/schedulers/flow_match_heun_discrete
|
||||
title: FlowMatchHeunDiscreteScheduler
|
||||
- local: api/schedulers/heun
|
||||
title: HeunDiscreteScheduler
|
||||
- local: api/schedulers/ipndm
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AuraFlowTransformer2DModel
|
||||
|
||||
A Transformer model for image-like data from [AuraFlow](https://blog.fal.ai/auraflow/).
|
||||
|
||||
## AuraFlowTransformer2DModel
|
||||
|
||||
[[autodoc]] AuraFlowTransformer2DModel
|
||||
@@ -1,19 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
## LatteTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [Latte](https://github.com/Vchitect/Latte).
|
||||
|
||||
## LatteTransformer3DModel
|
||||
|
||||
[[autodoc]] LatteTransformer3DModel
|
||||
@@ -1,20 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LuminaNextDiT2DModel
|
||||
|
||||
A Next Version of Diffusion Transformer model for 2D data from [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X).
|
||||
|
||||
## LuminaNextDiT2DModel
|
||||
|
||||
[[autodoc]] LuminaNextDiT2DModel
|
||||
|
||||
@@ -560,20 +560,6 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
|
||||
</table>
|
||||
|
||||
|
||||
## Using `from_single_file` with the MotionAdapter
|
||||
|
||||
`diffusers>=0.30.0` supports loading the AnimateDiff checkpoints into the `MotionAdapter` in their original format via `from_single_file`
|
||||
|
||||
```python
|
||||
from diffusers import MotionAdapter
|
||||
|
||||
ckpt_path = "https://huggingface.co/Lightricks/LongAnimateDiff/blob/main/lt_long_mm_32_frames.ckpt"
|
||||
|
||||
adapter = MotionAdapter.from_single_file(ckpt_path, torch_dtype=torch.float16)
|
||||
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter)
|
||||
|
||||
```
|
||||
|
||||
## AnimateDiffPipeline
|
||||
|
||||
[[autodoc]] AnimateDiffPipeline
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AuraFlow
|
||||
|
||||
AuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stable_diffusion_3.md) and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. This model achieves state-of-the-art results on the [GenEval](https://github.com/djghosh13/geneval) benchmark.
|
||||
|
||||
It was developed by the Fal team and more details about it can be found in [this blog post](https://blog.fal.ai/auraflow/).
|
||||
|
||||
<Tip>
|
||||
|
||||
AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details.
|
||||
|
||||
</Tip>
|
||||
|
||||
## AuraFlowPipeline
|
||||
|
||||
[[autodoc]] AuraFlowPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -1,49 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Kolors: Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis
|
||||
|
||||

|
||||
|
||||
Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](kwai-kolors@kuaishou.com). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).
|
||||
|
||||
The abstract from the technical report is:
|
||||
|
||||
*We present Kolors, a latent diffusion model for text-to-image synthesis, characterized by its profound understanding of both English and Chinese, as well as an impressive degree of photorealism. There are three key insights contributing to the development of Kolors. Firstly, unlike large language model T5 used in Imagen and Stable Diffusion 3, Kolors is built upon the General Language Model (GLM), which enhances its comprehension capabilities in both English and Chinese. Moreover, we employ a multimodal large language model to recaption the extensive training dataset for fine-grained text understanding. These strategies significantly improve Kolors’ ability to comprehend intricate semantics, particularly those involving multiple entities, and enable its advanced text rendering capabilities. Secondly, we divide the training of Kolors into two phases: the concept learning phase with broad knowledge and the quality improvement phase with specifically curated high-aesthetic data. Furthermore, we investigate the critical role of the noise schedule and introduce a novel schedule to optimize high-resolution image generation. These strategies collectively enhance the visual appeal of the generated high-resolution images. Lastly, we propose a category-balanced benchmark KolorsPrompts, which serves as a guide for the training and evaluation of Kolors. Consequently, even when employing the commonly used U-Net backbone, Kolors has demonstrated remarkable performance in human evaluations, surpassing the existing open-source models and achieving Midjourney-v6 level performance, especially in terms of visual appeal. We will release the code and weights of Kolors at <https://github.com/Kwai-Kolors/Kolors>, and hope that it will benefit future research and applications in the visual generation community.*
|
||||
|
||||
## Usage Example
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import DPMSolverMultistepScheduler, KolorsPipeline
|
||||
|
||||
pipe = KolorsPipeline.from_pretrained("Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16")
|
||||
pipe.to("cuda")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
|
||||
|
||||
image = pipe(
|
||||
prompt='一张瓢虫的照片,微距,变焦,高质量,电影,拿着一个牌子,写着"可图"',
|
||||
negative_prompt="",
|
||||
guidance_scale=6.5,
|
||||
num_inference_steps=25,
|
||||
).images[0]
|
||||
|
||||
image.save("kolors_sample.png")
|
||||
```
|
||||
|
||||
## KolorsPipeline
|
||||
|
||||
[[autodoc]] KolorsPipeline
|
||||
|
||||
- all
|
||||
- __call__
|
||||
@@ -1,88 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Lumina-T2X
|
||||

|
||||
|
||||
[Lumina-Next : Making Lumina-T2X Stronger and Faster with Next-DiT](https://github.com/Alpha-VLLM/Lumina-T2X/blob/main/assets/lumina-next.pdf) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Lumina-T2X is a nascent family of Flow-based Large Diffusion Transformers (Flag-DiT) that establishes a unified framework for transforming noise into various modalities, such as images and videos, conditioned on text instructions. Despite its promising capabilities, Lumina-T2X still encounters challenges including training instability, slow inference, and extrapolation artifacts. In this paper, we present Lumina-Next, an improved version of Lumina-T2X, showcasing stronger generation performance with increased training and inference efficiency. We begin with a comprehensive analysis of the Flag-DiT architecture and identify several suboptimal components, which we address by introducing the Next-DiT architecture with 3D RoPE and sandwich normalizations. To enable better resolution extrapolation, we thoroughly compare different context extrapolation methods applied to text-to-image generation with 3D RoPE, and propose Frequency- and Time-Aware Scaled RoPE tailored for diffusion transformers. Additionally, we introduce a sigmoid time discretization schedule to reduce sampling steps in solving the Flow ODE and the Context Drop method to merge redundant visual tokens for faster network evaluation, effectively boosting the overall sampling speed. Thanks to these improvements, Lumina-Next not only improves the quality and efficiency of basic text-to-image generation but also demonstrates superior resolution extrapolation capabilities and multilingual generation using decoder-based LLMs as the text encoder, all in a zero-shot manner. To further validate Lumina-Next as a versatile generative framework, we instantiate it on diverse tasks including visual recognition, multi-view, audio, music, and point cloud generation, showcasing strong performance across these domains. By releasing all codes and model weights at https://github.com/Alpha-VLLM/Lumina-T2X, we aim to advance the development of next-generation generative AI capable of universal modeling.*
|
||||
|
||||
**Highlights**: Lumina-Next is a next-generation Diffusion Transformer that significantly enhances text-to-image generation, multilingual generation, and multitask performance by introducing the Next-DiT architecture, 3D RoPE, and frequency- and time-aware RoPE, among other improvements.
|
||||
|
||||
Lumina-Next has the following components:
|
||||
* It improves sampling efficiency with fewer and faster Steps.
|
||||
* It uses a Next-DiT as a transformer backbone with Sandwichnorm 3D RoPE, and Grouped-Query Attention.
|
||||
* It uses a Frequency- and Time-Aware Scaled RoPE.
|
||||
|
||||
---
|
||||
|
||||
[Lumina-T2X: Transforming Text into Any Modality, Resolution, and Duration via Flow-based Large Diffusion Transformers](https://arxiv.org/abs/2405.05945) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Sora unveils the potential of scaling Diffusion Transformer for generating photorealistic images and videos at arbitrary resolutions, aspect ratios, and durations, yet it still lacks sufficient implementation details. In this technical report, we introduce the Lumina-T2X family - a series of Flow-based Large Diffusion Transformers (Flag-DiT) equipped with zero-initialized attention, as a unified framework designed to transform noise into images, videos, multi-view 3D objects, and audio clips conditioned on text instructions. By tokenizing the latent spatial-temporal space and incorporating learnable placeholders such as [nextline] and [nextframe] tokens, Lumina-T2X seamlessly unifies the representations of different modalities across various spatial-temporal resolutions. This unified approach enables training within a single framework for different modalities and allows for flexible generation of multimodal data at any resolution, aspect ratio, and length during inference. Advanced techniques like RoPE, RMSNorm, and flow matching enhance the stability, flexibility, and scalability of Flag-DiT, enabling models of Lumina-T2X to scale up to 7 billion parameters and extend the context window to 128K tokens. This is particularly beneficial for creating ultra-high-definition images with our Lumina-T2I model and long 720p videos with our Lumina-T2V model. Remarkably, Lumina-T2I, powered by a 5-billion-parameter Flag-DiT, requires only 35% of the training computational costs of a 600-million-parameter naive DiT. Our further comprehensive analysis underscores Lumina-T2X's preliminary capability in resolution extrapolation, high-resolution editing, generating consistent 3D views, and synthesizing videos with seamless transitions. We expect that the open-sourcing of Lumina-T2X will further foster creativity, transparency, and diversity in the generative AI community.*
|
||||
|
||||
|
||||
You can find the original codebase at [Alpha-VLLM](https://github.com/Alpha-VLLM/Lumina-T2X) and all the available checkpoints at [Alpha-VLLM Lumina Family](https://huggingface.co/collections/Alpha-VLLM/lumina-family-66423205bedb81171fd0644b).
|
||||
|
||||
**Highlights**: Lumina-T2X supports Any Modality, Resolution, and Duration.
|
||||
|
||||
Lumina-T2X has the following components:
|
||||
* It uses a Flow-based Large Diffusion Transformer as the backbone
|
||||
* It supports different any modalities with one backbone and corresponding encoder, decoder.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Inference (Text-to-Image)
|
||||
|
||||
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
|
||||
|
||||
First, load the pipeline:
|
||||
|
||||
```python
|
||||
from diffusers import LuminaText2ImgPipeline
|
||||
import torch
|
||||
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
|
||||
|
||||
```python
|
||||
pipeline.transformer.to(memory_format=torch.channels_last)
|
||||
pipeline.vae.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Finally, compile the components and run inference:
|
||||
|
||||
```python
|
||||
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
||||
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
|
||||
|
||||
image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures").images[0]
|
||||
```
|
||||
|
||||
## LuminaText2ImgPipeline
|
||||
|
||||
[[autodoc]] LuminaText2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -25,11 +25,6 @@ The abstract from the paper is:
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionControlNetPAGPipeline
|
||||
[[autodoc]] StableDiffusionControlNetPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionXLPAGPipeline
|
||||
[[autodoc]] StableDiffusionXLPAGPipeline
|
||||
- all
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# FlowMatchHeunDiscreteScheduler
|
||||
|
||||
`FlowMatchHeunDiscreteScheduler` is based on the flow-matching sampling introduced in [EDM](https://arxiv.org/abs/2403.03206).
|
||||
|
||||
## FlowMatchHeunDiscreteScheduler
|
||||
[[autodoc]] FlowMatchHeunDiscreteScheduler
|
||||
@@ -418,7 +418,7 @@ my_local_checkpoint_path = hf_hub_download(
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
|
||||
@@ -438,7 +438,7 @@ my_local_checkpoint_path = hf_hub_download(
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
local_dir="my_local_config"
|
||||
)
|
||||
|
||||
@@ -468,7 +468,7 @@ print("My local checkpoint: ", my_local_checkpoint_path)
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
print("My local config: ", my_local_config_path)
|
||||
|
||||
@@ -10,30 +10,30 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# 철학 [[philosophy]]
|
||||
# 철학
|
||||
|
||||
🧨 Diffusers는 다양한 모달리티에서 **최신의** 사전 훈련된 diffusion 모델을 제공합니다.
|
||||
그 목적은 추론과 훈련을 위한 **모듈식 툴박스**로 사용되는 것입니다.
|
||||
|
||||
저희는 시간이 지나도 변치 않는 라이브러리를 구축하는 것을 목표로 하기에 API 설계를 매우 중요하게 생각합니다.
|
||||
우리는 오랜 시간에 견딜 수 있는 라이브러리를 구축하는 것을 목표로 하고, 따라서 API 설계를 매우 중요시합니다.
|
||||
|
||||
간단히 말해서, Diffusers는 PyTorch를 자연스럽게 확장할 수 있도록 만들어졌습니다. 따라서 대부분의 설계 선택은 [PyTorch의 설계 원칙](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy)에 기반합니다. 이제 가장 중요한 것들을 살펴보겠습니다:
|
||||
간단히 말해서, Diffusers는 PyTorch의 자연스러운 확장이 되도록 구축되었습니다. 따라서 대부분의 설계 선택은 [PyTorch의 설계 원칙](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy)에 기반합니다. 이제 가장 중요한 것들을 살펴보겠습니다:
|
||||
|
||||
## 성능보다는 사용성을 [[usability-over-performance]]
|
||||
## 성능보다는 사용성을
|
||||
|
||||
- Diffusers는 다양한 성능 향상 기능이 내장되어 있지만 (자세한 내용은 [메모리와 속도](https://huggingface.co/docs/diffusers/optimization/fp16) 참조), 모델은 항상 가장 높은 정밀도와 최소한의 최적화로 로드됩니다. 따라서 사용자가 별도로 정의하지 않는 한 기본적으로 diffusion 파이프라인은 항상 float32 정밀도로 CPU에 인스턴스화됩니다. 이는 다양한 플랫폼과 가속기에서의 사용성을 보장하며, 라이브러리를 실행하기 위해 복잡한 설치가 필요하지 않다는 것을 의미합니다.
|
||||
- Diffusers는 많은 내장 성능 향상 기능을 갖고 있지만 (자세한 내용은 [메모리와 속도](https://huggingface.co/docs/diffusers/optimization/fp16) 참조), 모델은 항상 가장 높은 정밀도와 최소한의 최적화로 로드됩니다. 따라서 기본적인 diffusion 파이프라인은 따로 정의하지 않는다면 CPU에서 float32 정밀도로 인스턴스화됩니다. 이는 다양한 플랫폼과 가속기에서의 사용성을 보장하며, 라이브러리를 실행하기 위해 복잡한 설치가 필요하지 않음을 의미합니다.
|
||||
- Diffusers는 **가벼운** 패키지를 지향하기 때문에 필수 종속성은 거의 없지만 성능을 향상시킬 수 있는 많은 선택적 종속성이 있습니다 (`accelerate`, `safetensors`, `onnx` 등). 저희는 라이브러리를 가능한 한 가볍게 유지하여 다른 패키지에 대한 종속성 걱정이 없도록 노력하고 있습니다.
|
||||
- Diffusers는 간결하고 이해하기 쉬운 코드를 선호합니다. 이는 람다 함수나 고급 PyTorch 연산자와 같은 압축된 코드 구문을 자주 사용하지 않는 것을 의미합니다.
|
||||
|
||||
## 쉬움보다는 간단함을 [[simple-over-easy]]
|
||||
## 쉬움보다는 간단함을
|
||||
|
||||
PyTorch에서는 **명시적인 것이 암시적인 것보다 낫다**와 **단순한 것이 복잡한 것보다 낫다**라고 말합니다. 이 설계 철학은 라이브러리의 여러 부분에 반영되어 있습니다:
|
||||
- [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to)와 같은 메소드를 사용하여 사용자가 장치 관리를 할 수 있도록 PyTorch의 API를 따릅니다.
|
||||
- [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to)와 같은 메서드를 사용하여 사용자가 장치 관리를 할 수 있도록 PyTorch의 API를 따릅니다.
|
||||
- 잘못된 입력을 조용히 수정하는 대신 간결한 오류 메시지를 발생시키는 것이 우선입니다. Diffusers는 라이브러리를 가능한 한 쉽게 사용할 수 있도록 하는 것보다 사용자를 가르치는 것을 목표로 합니다.
|
||||
- 복잡한 모델과 스케줄러 로직이 내부에서 마법처럼 처리하는 대신 노출됩니다. 스케줄러/샘플러는 서로에게 최소한의 종속성을 가지고 분리되어 있습니다. 이로써 사용자는 언롤된 노이즈 제거 루프를 작성해야 합니다. 그러나 이 분리는 디버깅을 더 쉽게하고 노이즈 제거 과정을 조정하거나 diffusers 모델이나 스케줄러를 교체하는 데 사용자에게 더 많은 제어권을 제공합니다.
|
||||
- diffusers 파이프라인의 따로 훈련된 구성 요소인 text encoder, unet 및 variational autoencoder는 각각 자체 모델 클래스를 갖습니다. 이로써 사용자는 서로 다른 모델의 구성 요소 간의 상호 작용을 처리해야 하며, 직렬화 형식은 모델 구성 요소를 다른 파일로 분리합니다. 그러나 이는 디버깅과 커스터마이징을 더 쉽게합니다. DreamBooth나 Textual Inversion 훈련은 Diffusers의 'diffusion 파이프라인의 단일 구성 요소들을 분리할 수 있는 능력' 덕분에 매우 간단합니다.
|
||||
|
||||
## 추상화보다는 수정 가능하고 기여하기 쉬움을 [[tweakable-contributor-friendly-over-abstraction]]
|
||||
## 추상화보다는 수정 가능하고 기여하기 쉬움을
|
||||
|
||||
라이브러리의 대부분에 대해 Diffusers는 [Transformers 라이브러리](https://github.com/huggingface/transformers)의 중요한 설계 원칙을 채택합니다, 바로 성급한 추상화보다는 copy-pasted 코드를 선호한다는 것입니다. 이 설계 원칙은 [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself)와 같은 인기 있는 설계 원칙과는 대조적으로 매우 의견이 분분한데요.
|
||||
간단히 말해서, Transformers가 모델링 파일에 대해 수행하는 것처럼, Diffusers는 매우 낮은 수준의 추상화와 매우 독립적인 코드를 유지하는 것을 선호합니다. 함수, 긴 코드 블록, 심지어 클래스도 여러 파일에 복사할 수 있으며, 이는 처음에는 라이브러리를 유지할 수 없게 만드는 나쁜, 서투른 설계 선택으로 보일 수 있습니다. 하지만 이러한 설계는 매우 성공적이며, 커뮤니티 기반의 오픈 소스 기계 학습 라이브러리에 매우 적합합니다. 그 이유는 다음과 같습니다:
|
||||
@@ -48,11 +48,11 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
|
||||
좋아요, 이제 🧨 Diffusers가 설계된 방식을 대략적으로 이해했을 것입니다 🤗.
|
||||
우리는 이러한 설계 원칙을 일관되게 라이브러리 전체에 적용하려고 노력하고 있습니다. 그럼에도 불구하고 철학에 대한 일부 예외 사항이나 불행한 설계 선택이 있을 수 있습니다. 디자인에 대한 피드백이 있다면 [GitHub에서 직접](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) 알려주시면 감사하겠습니다.
|
||||
|
||||
## 디자인 철학 자세히 알아보기 [[design-philosophy-in-details]]
|
||||
## 디자인 철학 자세히 알아보기
|
||||
|
||||
이제 디자인 철학의 세부 사항을 좀 더 자세히 살펴보겠습니다. Diffusers는 주로 세 가지 주요 클래스로 구성됩니다: [파이프라인](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines), [모델](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models), 그리고 [스케줄러](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers). 각 클래스에 대한 더 자세한 설계 결정 사항을 살펴보겠습니다.
|
||||
|
||||
### 파이프라인 [[pipelines]]
|
||||
### 파이프라인
|
||||
|
||||
파이프라인은 사용하기 쉽도록 설계되었으며 (따라서 [*쉬움보다는 간단함을*](#쉬움보다는-간단함을)을 100% 따르지는 않음), feature-complete하지 않으며, 추론을 위한 [모델](#모델)과 [스케줄러](#스케줄러)를 사용하는 방법의 예시로 간주될 수 있습니다.
|
||||
|
||||
@@ -65,11 +65,11 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
|
||||
- 파이프라인은 매우 가독성이 좋고, 이해하기 쉽고, 쉽게 조정할 수 있도록 설계되어야 합니다.
|
||||
- 파이프라인은 서로 상호작용하고, 상위 수준 API에 쉽게 통합할 수 있도록 설계되어야 합니다.
|
||||
- 파이프라인은 사용자 인터페이스가 feature-complete하지 않게 하는 것을 목표로 합니다. future-complete한 사용자 인터페이스를 원한다면 [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), [lama-cleaner](https://github.com/Sanster/lama-cleaner)를 참조해야 합니다.
|
||||
- 모든 파이프라인은 오로지 `__call__` 메소드를 통해 실행할 수 있어야 합니다. `__call__` 인자의 이름은 모든 파이프라인에서 공유되어야 합니다.
|
||||
- 모든 파이프라인은 오로지 `__call__` 메서드를 통해 실행할 수 있어야 합니다. `__call__` 인자의 이름은 모든 파이프라인에서 공유되어야 합니다.
|
||||
- 파이프라인은 해결하고자 하는 작업의 이름으로 지정되어야 합니다.
|
||||
- 대부분의 경우에 새로운 diffusion 파이프라인은 새로운 파이프라인 폴더/파일에 구현되어야 합니다.
|
||||
|
||||
### 모델 [[models]]
|
||||
### 모델
|
||||
|
||||
모델은 [PyTorch의 Module 클래스](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)의 자연스러운 확장이 되도록, 구성 가능한 툴박스로 설계되었습니다. 그리고 모델은 **단일 파일 정책**을 일부만 따릅니다.
|
||||
|
||||
@@ -85,7 +85,7 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
|
||||
- 모델은 미래의 변경 사항을 쉽게 확장할 수 있도록 설계되어야 합니다. 이는 공개 함수 인수들과 구성 인수들을 제한하고,미래의 변경 사항을 "예상"하는 것을 통해 달성할 수 있습니다. 예를 들어, 불리언 `is_..._type` 인수보다는 새로운 미래 유형에 쉽게 확장할 수 있는 문자열 "...type" 인수를 추가하는 것이 일반적으로 더 좋습니다. 새로운 모델 체크포인트가 작동하도록 하기 위해 기존 아키텍처에 최소한의 변경만을 가해야 합니다.
|
||||
- 모델 디자인은 코드의 가독성과 간결성을 유지하는 것과 많은 모델 체크포인트를 지원하는 것 사이의 어려운 균형 조절입니다. 모델링 코드의 대부분은 새로운 모델 체크포인트를 위해 클래스를 수정하는 것이 좋지만, [UNet 블록](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) 및 [Attention 프로세서](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)와 같이 코드를 장기적으로 간결하고 읽기 쉽게 유지하기 위해 새로운 클래스를 추가하는 예외도 있습니다.
|
||||
|
||||
### 스케줄러 [[schedulers]]
|
||||
### 스케줄러
|
||||
|
||||
스케줄러는 추론을 위한 노이즈 제거 과정을 안내하고 훈련을 위한 노이즈 스케줄을 정의하는 역할을 합니다. 스케줄러는 개별 클래스로 설계되어 있으며, 로드 가능한 구성 파일과 **단일 파일 정책**을 엄격히 따릅니다.
|
||||
|
||||
@@ -95,7 +95,7 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
|
||||
- 하나의 스케줄러 Python 파일은 하나의 스케줄러 알고리즘(논문에서 정의된 것과 같은)에 해당합니다.
|
||||
- 스케줄러가 유사한 기능을 공유하는 경우, `# Copied from` 메커니즘을 사용할 수 있습니다.
|
||||
- 모든 스케줄러는 `SchedulerMixin`과 `ConfigMixin`을 상속합니다.
|
||||
- [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) 메소드를 사용하여 스케줄러를 쉽게 교체할 수 있습니다. 자세한 내용은 [여기](../using-diffusers/schedulers.md)에서 설명합니다.
|
||||
- [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) 메서드를 사용하여 스케줄러를 쉽게 교체할 수 있습니다. 자세한 내용은 [여기](../using-diffusers/schedulers.md)에서 설명합니다.
|
||||
- 모든 스케줄러는 `set_num_inference_steps`와 `step` 함수를 가져야 합니다. `set_num_inference_steps(...)`는 각 노이즈 제거 과정(즉, `step(...)`이 호출되기 전) 이전에 호출되어야 합니다.
|
||||
- 각 스케줄러는 모델이 호출될 타임스텝의 배열인 `timesteps` 속성을 통해 루프를 돌 수 있는 타임스텝을 노출합니다.
|
||||
- `step(...)` 함수는 예측된 모델 출력과 "현재" 샘플(x_t)을 입력으로 받고, "이전" 약간 더 노이즈가 제거된 샘플(x_t-1)을 반환합니다.
|
||||
|
||||
@@ -1290,7 +1290,6 @@ def main(args):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
@@ -1857,10 +1856,10 @@ def main(args):
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
@@ -1881,6 +1880,7 @@ def main(args):
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -1982,7 +1982,7 @@ def main(args):
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
|
||||
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
|
||||
@@ -573,13 +573,6 @@ def parse_args(input_args=None):
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_skip",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
|
||||
"the output of the pre-final layer will be used for computing the prompt embeddings.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--text_encoder_lr",
|
||||
@@ -1243,7 +1236,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
|
||||
|
||||
|
||||
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
|
||||
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
|
||||
prompt_embeds_list = []
|
||||
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
@@ -1260,11 +1253,7 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, c
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds[-1][-2]
|
||||
else:
|
||||
# "2" because SDXL always indexes from the penultimate layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]
|
||||
prompt_embeds = prompt_embeds[-1][-2]
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
@@ -1841,9 +1830,9 @@ def main(args):
|
||||
tokenizers = [tokenizer_one, tokenizer_two]
|
||||
text_encoders = [text_encoder_one, text_encoder_two]
|
||||
|
||||
def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
|
||||
def compute_text_embeddings(prompt, text_encoders, tokenizers):
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip)
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
@@ -1853,7 +1842,7 @@ def main(args):
|
||||
# the redundant encoding.
|
||||
if freeze_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.instance_prompt, text_encoders, tokenizers, args.clip_skip
|
||||
args.instance_prompt, text_encoders, tokenizers
|
||||
)
|
||||
|
||||
# Handle class prompt for prior-preservation.
|
||||
@@ -2063,7 +2052,7 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if freeze_text_encoder:
|
||||
prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
|
||||
prompts, text_encoders, tokenizers, args.clip_skip
|
||||
prompts, text_encoders, tokenizers
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -2158,7 +2147,6 @@ def main(args):
|
||||
tokenizers=None,
|
||||
prompt=None,
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
unet_added_conditions.update(
|
||||
{"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
|
||||
@@ -2425,7 +2413,7 @@ def main(args):
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
|
||||
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
|
||||
@@ -70,7 +70,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
|
||||
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
|
||||
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
|
||||
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
@@ -4100,117 +4099,6 @@ output_frames[0].save(output_video_path, save_all=True,
|
||||
append_images=output_frames[1:], duration=100, loop=0)
|
||||
```
|
||||
|
||||
### AnimateDiff on IPEX
|
||||
|
||||
This diffusion pipeline aims to accelerate the inference of AnimateDiff on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).
|
||||
|
||||
To use this pipeline, you need to:
|
||||
1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)
|
||||
|
||||
**Note:** For each PyTorch release, there is a corresponding release of IPEX. Here is the mapping relationship. It is recommended to install Pytorch/IPEX2.3 to get the best performance.
|
||||
|
||||
|PyTorch Version|IPEX Version|
|
||||
|--|--|
|
||||
|[v2.3.\*](https://github.com/pytorch/pytorch/tree/v2.3.0 "v2.3.0")|[v2.3.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0+cpu)|
|
||||
|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|
|
||||
|
||||
You can simply use pip to install IPEX with the latest version.
|
||||
```python
|
||||
python -m pip install intel_extension_for_pytorch
|
||||
```
|
||||
**Note:** To install a specific version, run with the following command:
|
||||
```
|
||||
python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
|
||||
```
|
||||
2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX accelaration. Supported inference datatypes are Float32 and BFloat16.
|
||||
|
||||
```python
|
||||
pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
|
||||
# For Float32
|
||||
pipe.prepare_for_ipex(torch.float32, prompt="A girl smiling")
|
||||
# For BFloat16
|
||||
pipe.prepare_for_ipex(torch.bfloat16, prompt="A girl smiling")
|
||||
```
|
||||
|
||||
Then you can use the ipex pipeline in a similar way to the default animatediff pipeline.
|
||||
```python
|
||||
# For Float32
|
||||
output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
|
||||
# For BFloat16
|
||||
with torch.cpu.amp.autocast(enabled = True, dtype = torch.bfloat16):
|
||||
output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
|
||||
```
|
||||
|
||||
The following code compares the performance of the original animatediff pipeline with the ipex-optimized pipeline.
|
||||
By using this optimized pipeline, we can get about 1.5-2.2 times performance boost with BFloat16 on the fifth generation of Intel Xeon CPUs, code-named Emerald Rapids.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MotionAdapter, AnimateDiffPipeline, EulerDiscreteScheduler
|
||||
from safetensors.torch import load_file
|
||||
from pipeline_animatediff_ipex import AnimateDiffPipelineIpex
|
||||
import time
|
||||
|
||||
device = "cpu"
|
||||
dtype = torch.float32
|
||||
|
||||
prompt = "A girl smiling"
|
||||
step = 8 # Options: [1,2,4,8]
|
||||
repo = "ByteDance/AnimateDiff-Lightning"
|
||||
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
|
||||
base = "emilianJR/epiCRealism" # Choose to your favorite base model.
|
||||
|
||||
adapter = MotionAdapter().to(device, dtype)
|
||||
adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
|
||||
|
||||
# Helper function for time evaluation
|
||||
def elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):
|
||||
# warmup
|
||||
for _ in range(2):
|
||||
output = pipeline(prompt = prompt, guidance_scale=1.0, num_inference_steps = num_inference_steps)
|
||||
#time evaluation
|
||||
start = time.time()
|
||||
for _ in range(nb_pass):
|
||||
pipeline(prompt = prompt, guidance_scale=1.0, num_inference_steps = num_inference_steps)
|
||||
end = time.time()
|
||||
return (end - start) / nb_pass
|
||||
|
||||
############## bf16 inference performance ###############
|
||||
|
||||
# 1. IPEX Pipeline initialization
|
||||
pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
|
||||
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
||||
pipe.prepare_for_ipex(torch.bfloat16, prompt = prompt)
|
||||
|
||||
# 2. Original Pipeline initialization
|
||||
pipe2 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
|
||||
pipe2.scheduler = EulerDiscreteScheduler.from_config(pipe2.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
||||
|
||||
# 3. Compare performance between Original Pipeline and IPEX Pipeline
|
||||
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
||||
latency = elapsed_time(pipe, num_inference_steps=step)
|
||||
print("Latency of AnimateDiffPipelineIpex--bf16", latency, "s for total", step, "steps")
|
||||
latency = elapsed_time(pipe2, num_inference_steps=step)
|
||||
print("Latency of AnimateDiffPipeline--bf16", latency, "s for total", step, "steps")
|
||||
|
||||
############## fp32 inference performance ###############
|
||||
|
||||
# 1. IPEX Pipeline initialization
|
||||
pipe3 = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
|
||||
pipe3.scheduler = EulerDiscreteScheduler.from_config(pipe3.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
||||
pipe3.prepare_for_ipex(torch.float32, prompt = prompt)
|
||||
|
||||
# 2. Original Pipeline initialization
|
||||
pipe4 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
|
||||
pipe4.scheduler = EulerDiscreteScheduler.from_config(pipe4.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
||||
|
||||
# 3. Compare performance between Original Pipeline and IPEX Pipeline
|
||||
latency = elapsed_time(pipe3, num_inference_steps=step)
|
||||
print("Latency of AnimateDiffPipelineIpex--fp32", latency, "s for total", step, "steps")
|
||||
latency = elapsed_time(pipe4, num_inference_steps=step)
|
||||
print("Latency of AnimateDiffPipeline--fp32",latency, "s for total", step, "steps")
|
||||
```
|
||||
|
||||
# Perturbed-Attention Guidance
|
||||
|
||||
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -467,6 +467,8 @@ def make_emblist(self, prompts):
|
||||
|
||||
|
||||
def split_dims(xs, height, width):
|
||||
xs = xs
|
||||
|
||||
def repeat_div(x, y):
|
||||
while y > 0:
|
||||
x = math.ceil(x / 2)
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRASD3(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
|
||||
|
||||
def test_dreambooth_lora_sd3(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_text_encoder_sd3(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
starts_with_expected_prefix = all(
|
||||
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_expected_prefix)
|
||||
|
||||
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -1,203 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from diffusers import DiffusionPipeline, SD3Transformer2DModel
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothSD3(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_sd3.py"
|
||||
|
||||
def test_dreambooth(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
|
||||
|
||||
def test_dreambooth_checkpointing(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 4
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
# check can run the original fully trained output pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir)
|
||||
pipe(self.instance_prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
|
||||
# check can run an intermediate checkpoint
|
||||
transformer = SD3Transformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
|
||||
pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
|
||||
pipe(self.instance_prompt, num_inference_steps=1)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
|
||||
# Run training script for 7 total steps resuming from checkpoint 4
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 6
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir)
|
||||
pipe(self.instance_prompt, num_inference_steps=1)
|
||||
|
||||
# check old checkpoints do not exist
|
||||
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
|
||||
# check new checkpoints exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
|
||||
|
||||
def test_dreambooth_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -101,37 +101,19 @@ def save_model_card(
|
||||
|
||||
## Model description
|
||||
|
||||
These are {repo_id} DreamBooth LoRA weights for {base_model}.
|
||||
These are {repo_id} DreamBooth weights for {base_model}.
|
||||
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
|
||||
|
||||
Was LoRA for the text encoder enabled? {train_text_encoder}.
|
||||
LoRA for the text encoder was enabled: {train_text_encoder}.
|
||||
|
||||
## Trigger words
|
||||
|
||||
You should use `{instance_prompt}` to trigger the image generation.
|
||||
You should use {instance_prompt} to trigger the image generation.
|
||||
|
||||
## Download model
|
||||
|
||||
[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
|
||||
|
||||
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
|
||||
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
|
||||
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
|
||||
```
|
||||
|
||||
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
|
||||
|
||||
- **LoRA**: download **[`diffusers_lora_weights.safetensors` here 💾](/{repo_id}/blob/main/diffusers_lora_weights.safetensors)**.
|
||||
- Rename it and place it on your `models/Lora` folder.
|
||||
- On AUTOMATIC1111, load the LoRA by adding `<lora:your_new_name:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
|
||||
|
||||
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
||||
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
||||
|
||||
## License
|
||||
|
||||
@@ -980,7 +962,7 @@ def encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device if device is not None else text_encoder.device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
|
||||
text_input_ids=text_input_ids_list[i],
|
||||
)
|
||||
clip_prompt_embeds_list.append(prompt_embeds)
|
||||
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
||||
@@ -994,7 +976,7 @@ def encode_prompt(
|
||||
max_sequence_length,
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
|
||||
text_input_ids=text_input_ids_list[:-1],
|
||||
device=device if device is not None else text_encoders[-1].device,
|
||||
)
|
||||
|
||||
@@ -1509,9 +1491,6 @@ def main(args):
|
||||
) = accelerator.prepare(
|
||||
transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
assert text_encoder_one is not None
|
||||
assert text_encoder_two is not None
|
||||
assert text_encoder_three is not None
|
||||
else:
|
||||
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, optimizer, train_dataloader, lr_scheduler
|
||||
@@ -1619,7 +1598,7 @@ def main(args):
|
||||
tokens_three = tokenize_prompt(tokenizer_three, prompts)
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
|
||||
tokenizers=[None, None, None],
|
||||
tokenizers=[None, None, tokenizer_three],
|
||||
prompt=prompts,
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
|
||||
@@ -1629,7 +1608,7 @@ def main(args):
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
|
||||
tokenizers=[None, None, tokenizer_three],
|
||||
prompt=args.instance_prompt,
|
||||
prompt=prompts,
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
|
||||
)
|
||||
@@ -1706,12 +1685,10 @@ def main(args):
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(
|
||||
transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
|
||||
)
|
||||
if args.train_text_encoder
|
||||
else transformer_lora_parameters
|
||||
params_to_clip = itertools.chain(
|
||||
transformer_lora_parameters,
|
||||
text_lora_parameters_one,
|
||||
text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters,
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
@@ -1764,6 +1741,13 @@ def main(args):
|
||||
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
|
||||
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
|
||||
)
|
||||
else:
|
||||
text_encoder_three = text_encoder_cls_three.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder_3",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
@@ -1783,9 +1767,7 @@ def main(args):
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
@@ -95,22 +95,17 @@ def save_model_card(
|
||||
|
||||
These are {repo_id} DreamBooth weights for {base_model}.
|
||||
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
|
||||
|
||||
Was the text encoder fine-tuned? {train_text_encoder}.
|
||||
Text encoder was fine-tuned: {train_text_encoder}.
|
||||
|
||||
## Trigger words
|
||||
|
||||
You should use `{instance_prompt}` to trigger the image generation.
|
||||
You should use {instance_prompt} to trigger the image generation.
|
||||
|
||||
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||
## Download model
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.float16).to('cuda')
|
||||
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
|
||||
```
|
||||
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
||||
|
||||
## License
|
||||
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
# Running Stable Diffusion 3 DreamBooth LoRA training under 16GB
|
||||
|
||||
This is an **EDUCATIONAL** project that provides utilities for DreamBooth LoRA training for [Stable Diffusion 3 (SD3)](ttps://huggingface.co/papers/2403.03206) under 16GB GPU VRAM. This means you can successfully try out this project using a [free-tier Colab Notebook](https://colab.research.google.com/github/huggingface/diffusers/blob/main/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb) instance. 🤗
|
||||
|
||||
> [!NOTE]
|
||||
> SD3 is gated, so you need to make sure you agree to [share your contact info](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) to access the model before using it with Diffusers. Once you have access, you need to log in so your system knows you’re authorized. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above.
|
||||
|
||||
## How
|
||||
|
||||
We make use of several techniques to make this possible:
|
||||
|
||||
* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) T5 to reduce memory requirements to ~10.5GB.
|
||||
* In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of:
|
||||
* 8bit Adam for optimization through the `bitsandbytes` library.
|
||||
* Gradient checkpointing and gradient accumulation.
|
||||
* FP16 precision.
|
||||
* Flash attention through `F.scaled_dot_product_attention()`.
|
||||
|
||||
Computing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. With FP16, we are down to 12GB.
|
||||
|
||||
|
||||
## Gotchas
|
||||
|
||||
This project is educational. It exists to showcase the possibility of fine-tuning a big diffusion system on consumer GPUs. But additional components might have to be added to obtain state-of-the-art performance. Below are some commonly known gotchas that users should be aware of:
|
||||
|
||||
* Training of text encoders is purposefully disabled.
|
||||
* Techniques such as prior-preservation is unsupported.
|
||||
* Custom instance captions for instance images are unsupported, but this should be relatively easy to integrate.
|
||||
|
||||
Hopefully, this project gives you a template to extend it further to suit your needs.
|
||||
@@ -1,123 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import hashlib
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
|
||||
PROMPT = "a photo of sks dog"
|
||||
MAX_SEQ_LENGTH = 77
|
||||
LOCAL_DATA_DIR = "dog"
|
||||
OUTPUT_PATH = "sample_embeddings.parquet"
|
||||
|
||||
|
||||
def bytes_to_giga_bytes(bytes):
|
||||
return bytes / 1024 / 1024 / 1024
|
||||
|
||||
|
||||
def generate_image_hash(image_path):
|
||||
with open(image_path, "rb") as f:
|
||||
img_data = f.read()
|
||||
return hashlib.sha256(img_data).hexdigest()
|
||||
|
||||
|
||||
def load_sd3_pipeline():
|
||||
id = "stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_3", load_in_8bit=True, device_map="auto")
|
||||
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
id, text_encoder_3=text_encoder, transformer=None, vae=None, device_map="balanced"
|
||||
)
|
||||
return pipeline
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_embeddings(pipeline, prompt, max_sequence_length):
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length)
|
||||
|
||||
print(
|
||||
f"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, {pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape}"
|
||||
)
|
||||
|
||||
max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
|
||||
print(f"Max memory allocated: {max_memory:.3f} GB")
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
|
||||
def run(args):
|
||||
pipeline = load_sd3_pipeline()
|
||||
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = compute_embeddings(
|
||||
pipeline, args.prompt, args.max_sequence_length
|
||||
)
|
||||
|
||||
# Assumes that the images within `args.local_image_dir` have a JPEG extension. Change
|
||||
# as needed.
|
||||
image_paths = glob.glob(f"{args.local_data_dir}/*.jpeg")
|
||||
data = []
|
||||
for image_path in image_paths:
|
||||
img_hash = generate_image_hash(image_path)
|
||||
data.append(
|
||||
(img_hash, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
|
||||
)
|
||||
|
||||
# Create a DataFrame
|
||||
embedding_cols = [
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
"pooled_prompt_embeds",
|
||||
"negative_pooled_prompt_embeds",
|
||||
]
|
||||
df = pd.DataFrame(
|
||||
data,
|
||||
columns=["image_hash"] + embedding_cols,
|
||||
)
|
||||
|
||||
# Convert embedding lists to arrays (for proper storage in parquet)
|
||||
for col in embedding_cols:
|
||||
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
|
||||
|
||||
# Save the dataframe to a parquet file
|
||||
df.to_parquet(args.output_path)
|
||||
print(f"Data successfully serialized to {args.output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--prompt", type=str, default=PROMPT, help="The instance prompt.")
|
||||
parser.add_argument(
|
||||
"--max_sequence_length",
|
||||
type=int,
|
||||
default=MAX_SEQ_LENGTH,
|
||||
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_data_dir", type=str, default=LOCAL_DATA_DIR, help="Path to the directory containing instance images."
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
|
||||
args = parser.parse_args()
|
||||
|
||||
run(args)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,11 +0,0 @@
|
||||
# VAE
|
||||
|
||||
`vae_roundtrip.py` Demonstrates the use of a VAE by roundtripping an image through the encoder and decoder. Original and reconstructed images are displayed side by side.
|
||||
|
||||
```
|
||||
cd examples/research_projects/vae
|
||||
python vae_roundtrip.py \
|
||||
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
|
||||
--subfolder="vae" \
|
||||
--input_image="/path/to/your/input.png"
|
||||
```
|
||||
@@ -1,282 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms # type: ignore
|
||||
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.autoencoders.autoencoder_kl import (
|
||||
AutoencoderKL,
|
||||
AutoencoderKLOutput,
|
||||
)
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import (
|
||||
AutoencoderTiny,
|
||||
AutoencoderTinyOutput,
|
||||
)
|
||||
from diffusers.models.autoencoders.vae import DecoderOutput
|
||||
|
||||
|
||||
SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]
|
||||
|
||||
|
||||
def load_vae_model(
|
||||
*,
|
||||
device: torch.device,
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
variant: Optional[str],
|
||||
# NOTE: use subfolder="vae" if the pointed model is for stable diffusion as a whole instead of just the VAE
|
||||
subfolder: Optional[str],
|
||||
use_tiny_nn: bool,
|
||||
) -> SupportedAutoencoder:
|
||||
if use_tiny_nn:
|
||||
# NOTE: These scaling factors don't have to be the same as each other.
|
||||
down_scale = 2
|
||||
up_scale = 2
|
||||
vae = AutoencoderTiny.from_pretrained( # type: ignore
|
||||
model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
variant=variant,
|
||||
downscaling_scaling_factor=down_scale,
|
||||
upsampling_scaling_factor=up_scale,
|
||||
)
|
||||
assert isinstance(vae, AutoencoderTiny)
|
||||
else:
|
||||
vae = AutoencoderKL.from_pretrained( # type: ignore
|
||||
model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
variant=variant,
|
||||
)
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
vae = vae.to(device)
|
||||
vae.eval() # Set the model to inference mode
|
||||
return vae
|
||||
|
||||
|
||||
def pil_to_nhwc(
|
||||
*,
|
||||
device: torch.device,
|
||||
image: Image.Image,
|
||||
) -> torch.Tensor:
|
||||
assert image.mode == "RGB"
|
||||
transform = transforms.ToTensor()
|
||||
nhwc = transform(image).unsqueeze(0).to(device) # type: ignore
|
||||
assert isinstance(nhwc, torch.Tensor)
|
||||
return nhwc
|
||||
|
||||
|
||||
def nhwc_to_pil(
|
||||
*,
|
||||
nhwc: torch.Tensor,
|
||||
) -> Image.Image:
|
||||
assert nhwc.shape[0] == 1
|
||||
hwc = nhwc.squeeze(0).cpu()
|
||||
return transforms.ToPILImage()(hwc) # type: ignore
|
||||
|
||||
|
||||
def concatenate_images(
|
||||
*,
|
||||
left: Image.Image,
|
||||
right: Image.Image,
|
||||
vertical: bool = False,
|
||||
) -> Image.Image:
|
||||
width1, height1 = left.size
|
||||
width2, height2 = right.size
|
||||
if vertical:
|
||||
total_height = height1 + height2
|
||||
max_width = max(width1, width2)
|
||||
new_image = Image.new("RGB", (max_width, total_height))
|
||||
new_image.paste(left, (0, 0))
|
||||
new_image.paste(right, (0, height1))
|
||||
else:
|
||||
total_width = width1 + width2
|
||||
max_height = max(height1, height2)
|
||||
new_image = Image.new("RGB", (total_width, max_height))
|
||||
new_image.paste(left, (0, 0))
|
||||
new_image.paste(right, (width1, 0))
|
||||
return new_image
|
||||
|
||||
|
||||
def to_latent(
|
||||
*,
|
||||
rgb_nchw: torch.Tensor,
|
||||
vae: SupportedAutoencoder,
|
||||
) -> torch.Tensor:
|
||||
rgb_nchw = VaeImageProcessor.normalize(rgb_nchw) # type: ignore
|
||||
encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
|
||||
if isinstance(encoding_nchw, AutoencoderKLOutput):
|
||||
latent = encoding_nchw.latent_dist.sample() # type: ignore
|
||||
assert isinstance(latent, torch.Tensor)
|
||||
elif isinstance(encoding_nchw, AutoencoderTinyOutput):
|
||||
latent = encoding_nchw.latents
|
||||
do_internal_vae_scaling = False # Is this needed?
|
||||
if do_internal_vae_scaling:
|
||||
latent = vae.scale_latents(latent).mul(255).round().byte() # type: ignore
|
||||
latent = vae.unscale_latents(latent / 255.0) # type: ignore
|
||||
assert isinstance(latent, torch.Tensor)
|
||||
else:
|
||||
assert False, f"Unknown encoding type: {type(encoding_nchw)}"
|
||||
return latent
|
||||
|
||||
|
||||
def from_latent(
|
||||
*,
|
||||
latent_nchw: torch.Tensor,
|
||||
vae: SupportedAutoencoder,
|
||||
) -> torch.Tensor:
|
||||
decoding_nchw = vae.decode(latent_nchw) # type: ignore
|
||||
assert isinstance(decoding_nchw, DecoderOutput)
|
||||
rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample) # type: ignore
|
||||
assert isinstance(rgb_nchw, torch.Tensor)
|
||||
return rgb_nchw
|
||||
|
||||
|
||||
def main_kwargs(
|
||||
*,
|
||||
device: torch.device,
|
||||
input_image_path: str,
|
||||
pretrained_model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
variant: Optional[str],
|
||||
subfolder: Optional[str],
|
||||
use_tiny_nn: bool,
|
||||
) -> None:
|
||||
vae = load_vae_model(
|
||||
device=device,
|
||||
model_name_or_path=pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
variant=variant,
|
||||
subfolder=subfolder,
|
||||
use_tiny_nn=use_tiny_nn,
|
||||
)
|
||||
original_pil = Image.open(input_image_path).convert("RGB")
|
||||
original_image = pil_to_nhwc(
|
||||
device=device,
|
||||
image=original_pil,
|
||||
)
|
||||
print(f"Original image shape: {original_image.shape}")
|
||||
reconstructed_image: Optional[torch.Tensor] = None
|
||||
|
||||
with torch.no_grad():
|
||||
latent_image = to_latent(rgb_nchw=original_image, vae=vae)
|
||||
print(f"Latent shape: {latent_image.shape}")
|
||||
reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)
|
||||
reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image)
|
||||
combined_image = concatenate_images(
|
||||
left=original_pil,
|
||||
right=reconstructed_pil,
|
||||
vertical=False,
|
||||
)
|
||||
combined_image.show("Original | Reconstruction")
|
||||
print(f"Reconstructed image shape: {reconstructed_image.shape}")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Inference with VAE")
|
||||
parser.add_argument(
|
||||
"--input_image",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the input image for inference.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained VAE model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model version.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model file variant, e.g., 'fp16'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subfolder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subfolder in the model file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuda",
|
||||
action="store_true",
|
||||
help="Use CUDA if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_tiny_nn",
|
||||
action="store_true",
|
||||
help="Use tiny neural network.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# EXAMPLE USAGE:
|
||||
#
|
||||
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
|
||||
#
|
||||
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
|
||||
#
|
||||
def main_cli() -> None:
|
||||
args = parse_args()
|
||||
|
||||
input_image_path = args.input_image
|
||||
assert isinstance(input_image_path, str)
|
||||
|
||||
pretrained_model_name_or_path = args.pretrained_model_name_or_path
|
||||
assert isinstance(pretrained_model_name_or_path, str)
|
||||
|
||||
revision = args.revision
|
||||
assert isinstance(revision, (str, type(None)))
|
||||
|
||||
variant = args.variant
|
||||
assert isinstance(variant, (str, type(None)))
|
||||
|
||||
subfolder = args.subfolder
|
||||
assert isinstance(subfolder, (str, type(None)))
|
||||
|
||||
use_cuda = args.use_cuda
|
||||
assert isinstance(use_cuda, bool)
|
||||
|
||||
use_tiny_nn = args.use_tiny_nn
|
||||
assert isinstance(use_tiny_nn, bool)
|
||||
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
|
||||
main_kwargs(
|
||||
device=device,
|
||||
input_image_path=input_image_path,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
variant=variant,
|
||||
subfolder=subfolder,
|
||||
use_tiny_nn=use_tiny_nn,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_cli()
|
||||
@@ -1,95 +0,0 @@
|
||||
import torch
|
||||
from flash_attn_interface import flash_attn_func
|
||||
|
||||
class FA3AttnProcessor:
|
||||
r"""
|
||||
Processor for using Flash Attention 3 (FA3) via `flash-attn`.
|
||||
|
||||
To install `flash-attn` that supports FA3, follow:
|
||||
https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release
|
||||
|
||||
Reference: https://tridao.me/blog/2024/flash3/
|
||||
"""
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, key_tokens, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
||||
if attention_mask is not None:
|
||||
# expand our mask's singleton query_tokens dimension:
|
||||
# [batch*heads, 1, key_tokens] ->
|
||||
# [batch*heads, query_tokens, key_tokens]
|
||||
# so that it can be added as a bias onto the attention scores that xformers computes:
|
||||
# [batch*heads, query_tokens, key_tokens]
|
||||
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
||||
_, query_tokens, _ = hidden_states.shape
|
||||
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).contiguous()
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).contiguous()
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).contiguous()
|
||||
|
||||
# nasty hack to make the head number and head dim compatible with FA3.
|
||||
# if attn.heads ==1 and head_dim == 512:
|
||||
# factor = 8
|
||||
# new_head_dim = head_dim // factor
|
||||
# query = query.view(batch_size, -1, factor, new_head_dim)
|
||||
# key = key.view(batch_size, -1, factor, new_head_dim)
|
||||
# value = value.view(batch_size, -1, factor, new_head_dim)
|
||||
hidden_states, _ = flash_attn_func(
|
||||
query, key, value, softmax_scale=attn.scale, causal=False
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
@@ -1,131 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel
|
||||
|
||||
|
||||
def load_original_state_dict(args):
|
||||
model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin")
|
||||
state_dict = torch.load(model_pt, map_location="cpu")
|
||||
return state_dict
|
||||
|
||||
|
||||
def calculate_layers(state_dict_keys, key_prefix):
|
||||
dit_layers = set()
|
||||
for k in state_dict_keys:
|
||||
if key_prefix in k:
|
||||
dit_layers.add(int(k.split(".")[2]))
|
||||
print(f"{key_prefix}: {len(dit_layers)}")
|
||||
return len(dit_layers)
|
||||
|
||||
|
||||
# similar to SD3 but only for the last norm layer
|
||||
def swap_scale_shift(weight, dim):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_transformer(state_dict):
|
||||
converted_state_dict = {}
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
|
||||
converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
|
||||
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
|
||||
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")
|
||||
|
||||
converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
|
||||
converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
|
||||
converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")
|
||||
|
||||
converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")
|
||||
|
||||
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
|
||||
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
|
||||
|
||||
# MMDiT blocks 🎸.
|
||||
for i in range(mmdit_layers):
|
||||
# feed-forward
|
||||
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
|
||||
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
||||
for orig_k, diffuser_k in path_mapping.items():
|
||||
for k, v in weight_mapping.items():
|
||||
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop(
|
||||
f"model.double_layers.{i}.{orig_k}.{k}.weight"
|
||||
)
|
||||
|
||||
# norms
|
||||
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
|
||||
for orig_k, diffuser_k in path_mapping.items():
|
||||
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
|
||||
f"model.double_layers.{i}.{orig_k}.1.weight"
|
||||
)
|
||||
|
||||
# attns
|
||||
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
|
||||
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
|
||||
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
|
||||
for k, v in attn_mapping.items():
|
||||
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
|
||||
f"model.double_layers.{i}.attn.{k}.weight"
|
||||
)
|
||||
|
||||
# Single-DiT blocks.
|
||||
for i in range(single_dit_layers):
|
||||
# feed-forward
|
||||
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
||||
for k, v in mapping.items():
|
||||
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop(
|
||||
f"model.single_layers.{i}.mlp.{k}.weight"
|
||||
)
|
||||
|
||||
# norms
|
||||
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
|
||||
f"model.single_layers.{i}.modCX.1.weight"
|
||||
)
|
||||
|
||||
# attns
|
||||
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
|
||||
for k, v in x_attn_mapping.items():
|
||||
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
|
||||
f"model.single_layers.{i}.attn.{k}.weight"
|
||||
)
|
||||
|
||||
# Final blocks.
|
||||
converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def populate_state_dict(args):
|
||||
original_state_dict = load_original_state_dict(args)
|
||||
state_dict_keys = list(original_state_dict.keys())
|
||||
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
|
||||
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
|
||||
|
||||
converted_state_dict = convert_transformer(original_state_dict)
|
||||
model_diffusers = AuraFlowTransformer2DModel(
|
||||
num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
|
||||
)
|
||||
model_diffusers.load_state_dict(converted_state_dict, strict=True)
|
||||
|
||||
return model_diffusers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
|
||||
parser.add_argument("--dump_path", default="aura-flow", type=str)
|
||||
parser.add_argument("--hub_id", default=None, type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_diffusers = populate_state_dict(args)
|
||||
model_diffusers.save_pretrained(args.dump_path)
|
||||
if args.hub_id is not None:
|
||||
model_diffusers.push_to_hub(args.hub_id)
|
||||
@@ -1,241 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import HunyuanDiT2DControlNetModel
|
||||
|
||||
|
||||
def main(args):
|
||||
state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
|
||||
|
||||
if args.load_key != "none":
|
||||
try:
|
||||
state_dict = state_dict[args.load_key]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"{args.load_key} not found in the checkpoint."
|
||||
"Please load from the following keys:{state_dict.keys()}"
|
||||
)
|
||||
device = "cuda"
|
||||
|
||||
model_config = HunyuanDiT2DControlNetModel.load_config(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
|
||||
)
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
print(model_config)
|
||||
|
||||
for key in state_dict:
|
||||
print("local:", key)
|
||||
|
||||
model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)
|
||||
|
||||
for key in model.state_dict():
|
||||
print("diffusers:", key)
|
||||
|
||||
num_layers = 19
|
||||
for i in range(num_layers):
|
||||
# attn1
|
||||
# Wkqv -> to_q, to_k, to_v
|
||||
q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
|
||||
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
|
||||
state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
|
||||
state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
|
||||
state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
|
||||
state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
|
||||
state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
|
||||
state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
|
||||
state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
|
||||
|
||||
# q_norm, k_norm -> norm_q, norm_k
|
||||
state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
|
||||
|
||||
state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
|
||||
state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
|
||||
|
||||
# out_proj -> to_out
|
||||
state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
|
||||
|
||||
# attn2
|
||||
# kq_proj -> to_k, to_v
|
||||
k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
|
||||
state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
|
||||
state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
|
||||
state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
|
||||
state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
|
||||
state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
|
||||
|
||||
# q_proj -> to_q
|
||||
state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
|
||||
|
||||
# q_norm, k_norm -> norm_q, norm_k
|
||||
state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
|
||||
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
|
||||
state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
|
||||
|
||||
# out_proj -> to_out
|
||||
state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
|
||||
|
||||
# switch norm 2 and norm 3
|
||||
norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
|
||||
norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
|
||||
state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
|
||||
state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
|
||||
state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
|
||||
state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
|
||||
|
||||
# norm1 -> norm1.norm
|
||||
# default_modulation.1 -> norm1.linear
|
||||
state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
|
||||
state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
|
||||
state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
|
||||
state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
|
||||
state_dict.pop(f"blocks.{i}.norm1.weight")
|
||||
state_dict.pop(f"blocks.{i}.norm1.bias")
|
||||
state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
|
||||
state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
|
||||
|
||||
# mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
|
||||
state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
|
||||
state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
|
||||
state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
|
||||
state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
|
||||
|
||||
# after_proj_list -> controlnet_blocks
|
||||
state_dict[f"controlnet_blocks.{i}.weight"] = state_dict[f"after_proj_list.{i}.weight"]
|
||||
state_dict[f"controlnet_blocks.{i}.bias"] = state_dict[f"after_proj_list.{i}.bias"]
|
||||
state_dict.pop(f"after_proj_list.{i}.weight")
|
||||
state_dict.pop(f"after_proj_list.{i}.bias")
|
||||
|
||||
# before_proj -> input_block
|
||||
state_dict["input_block.weight"] = state_dict["before_proj.weight"]
|
||||
state_dict["input_block.bias"] = state_dict["before_proj.bias"]
|
||||
state_dict.pop("before_proj.weight")
|
||||
state_dict.pop("before_proj.bias")
|
||||
|
||||
# pooler -> time_extra_emb
|
||||
state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
|
||||
state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
|
||||
state_dict.pop("pooler.k_proj.weight")
|
||||
state_dict.pop("pooler.k_proj.bias")
|
||||
state_dict.pop("pooler.q_proj.weight")
|
||||
state_dict.pop("pooler.q_proj.bias")
|
||||
state_dict.pop("pooler.v_proj.weight")
|
||||
state_dict.pop("pooler.v_proj.bias")
|
||||
state_dict.pop("pooler.c_proj.weight")
|
||||
state_dict.pop("pooler.c_proj.bias")
|
||||
state_dict.pop("pooler.positional_embedding")
|
||||
|
||||
# t_embedder -> time_embedding (`TimestepEmbedding`)
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
|
||||
|
||||
state_dict.pop("t_embedder.mlp.0.bias")
|
||||
state_dict.pop("t_embedder.mlp.0.weight")
|
||||
state_dict.pop("t_embedder.mlp.2.bias")
|
||||
state_dict.pop("t_embedder.mlp.2.weight")
|
||||
|
||||
# x_embedder -> pos_embd (`PatchEmbed`)
|
||||
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
|
||||
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
|
||||
state_dict.pop("x_embedder.proj.weight")
|
||||
state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# mlp_t5 -> text_embedder
|
||||
state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
|
||||
state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
|
||||
state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
|
||||
state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
|
||||
state_dict.pop("mlp_t5.0.bias")
|
||||
state_dict.pop("mlp_t5.0.weight")
|
||||
state_dict.pop("mlp_t5.2.bias")
|
||||
state_dict.pop("mlp_t5.2.weight")
|
||||
|
||||
# extra_embedder -> extra_embedder
|
||||
state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
|
||||
state_dict.pop("extra_embedder.0.bias")
|
||||
state_dict.pop("extra_embedder.0.weight")
|
||||
state_dict.pop("extra_embedder.2.bias")
|
||||
state_dict.pop("extra_embedder.2.weight")
|
||||
|
||||
# style_embedder
|
||||
if model_config["use_style_cond_and_image_meta_size"]:
|
||||
print(state_dict["style_embedder.weight"])
|
||||
print(state_dict["style_embedder.weight"].shape)
|
||||
state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
|
||||
state_dict.pop("style_embedder.weight")
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
if args.save:
|
||||
model.save_pretrained(args.output_checkpoint_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the output converted diffusers pipeline.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_style_cond_and_image_meta_size",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="version <= v1.1: True; version >= v1.2: False",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -1,267 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import HunyuanDiT2DModel
|
||||
|
||||
|
||||
def main(args):
|
||||
state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
|
||||
|
||||
if args.load_key != "none":
|
||||
try:
|
||||
state_dict = state_dict[args.load_key]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"{args.load_key} not found in the checkpoint."
|
||||
f"Please load from the following keys:{state_dict.keys()}"
|
||||
)
|
||||
|
||||
device = "cuda"
|
||||
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
|
||||
# input_size -> sample_size, text_dim -> cross_attention_dim
|
||||
for key in state_dict:
|
||||
print("local:", key)
|
||||
|
||||
model = HunyuanDiT2DModel.from_config(model_config).to(device)
|
||||
|
||||
for key in model.state_dict():
|
||||
print("diffusers:", key)
|
||||
|
||||
num_layers = 40
|
||||
for i in range(num_layers):
|
||||
# attn1
|
||||
# Wkqv -> to_q, to_k, to_v
|
||||
q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
|
||||
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
|
||||
state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
|
||||
state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
|
||||
state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
|
||||
state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
|
||||
state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
|
||||
state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
|
||||
state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
|
||||
|
||||
# q_norm, k_norm -> norm_q, norm_k
|
||||
state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
|
||||
|
||||
state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
|
||||
state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
|
||||
|
||||
# out_proj -> to_out
|
||||
state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
|
||||
|
||||
# attn2
|
||||
# kq_proj -> to_k, to_v
|
||||
k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
|
||||
state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
|
||||
state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
|
||||
state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
|
||||
state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
|
||||
state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
|
||||
|
||||
# q_proj -> to_q
|
||||
state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
|
||||
|
||||
# q_norm, k_norm -> norm_q, norm_k
|
||||
state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
|
||||
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
|
||||
state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
|
||||
|
||||
# out_proj -> to_out
|
||||
state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
|
||||
|
||||
# switch norm 2 and norm 3
|
||||
norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
|
||||
norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
|
||||
state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
|
||||
state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
|
||||
state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
|
||||
state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
|
||||
|
||||
# norm1 -> norm1.norm
|
||||
# default_modulation.1 -> norm1.linear
|
||||
state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
|
||||
state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
|
||||
state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
|
||||
state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
|
||||
state_dict.pop(f"blocks.{i}.norm1.weight")
|
||||
state_dict.pop(f"blocks.{i}.norm1.bias")
|
||||
state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
|
||||
state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
|
||||
|
||||
# mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
|
||||
state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
|
||||
state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
|
||||
state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
|
||||
state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
|
||||
|
||||
# pooler -> time_extra_emb
|
||||
state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
|
||||
state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
|
||||
state_dict.pop("pooler.k_proj.weight")
|
||||
state_dict.pop("pooler.k_proj.bias")
|
||||
state_dict.pop("pooler.q_proj.weight")
|
||||
state_dict.pop("pooler.q_proj.bias")
|
||||
state_dict.pop("pooler.v_proj.weight")
|
||||
state_dict.pop("pooler.v_proj.bias")
|
||||
state_dict.pop("pooler.c_proj.weight")
|
||||
state_dict.pop("pooler.c_proj.bias")
|
||||
state_dict.pop("pooler.positional_embedding")
|
||||
|
||||
# t_embedder -> time_embedding (`TimestepEmbedding`)
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
|
||||
|
||||
state_dict.pop("t_embedder.mlp.0.bias")
|
||||
state_dict.pop("t_embedder.mlp.0.weight")
|
||||
state_dict.pop("t_embedder.mlp.2.bias")
|
||||
state_dict.pop("t_embedder.mlp.2.weight")
|
||||
|
||||
# x_embedder -> pos_embd (`PatchEmbed`)
|
||||
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
|
||||
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
|
||||
state_dict.pop("x_embedder.proj.weight")
|
||||
state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# mlp_t5 -> text_embedder
|
||||
state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
|
||||
state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
|
||||
state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
|
||||
state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
|
||||
state_dict.pop("mlp_t5.0.bias")
|
||||
state_dict.pop("mlp_t5.0.weight")
|
||||
state_dict.pop("mlp_t5.2.bias")
|
||||
state_dict.pop("mlp_t5.2.weight")
|
||||
|
||||
# extra_embedder -> extra_embedder
|
||||
state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
|
||||
state_dict.pop("extra_embedder.0.bias")
|
||||
state_dict.pop("extra_embedder.0.weight")
|
||||
state_dict.pop("extra_embedder.2.bias")
|
||||
state_dict.pop("extra_embedder.2.weight")
|
||||
|
||||
# model.final_adaLN_modulation.1 -> norm_out.linear
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"])
|
||||
state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"])
|
||||
state_dict.pop("final_layer.adaLN_modulation.1.weight")
|
||||
state_dict.pop("final_layer.adaLN_modulation.1.bias")
|
||||
|
||||
# final_linear -> proj_out
|
||||
state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"]
|
||||
state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"]
|
||||
state_dict.pop("final_layer.linear.weight")
|
||||
state_dict.pop("final_layer.linear.bias")
|
||||
|
||||
# style_embedder
|
||||
if model_config["use_style_cond_and_image_meta_size"]:
|
||||
print(state_dict["style_embedder.weight"])
|
||||
print(state_dict["style_embedder.weight"].shape)
|
||||
state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
|
||||
state_dict.pop("style_embedder.weight")
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
from diffusers import HunyuanDiTPipeline
|
||||
|
||||
if args.use_style_cond_and_image_meta_size:
|
||||
pipe = HunyuanDiTPipeline.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-Diffusers", transformer=model, torch_dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
pipe = HunyuanDiTPipeline.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", transformer=model, torch_dtype=torch.float32
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.to(dtype=torch.float32)
|
||||
|
||||
if args.save:
|
||||
pipe.save_pretrained(args.output_checkpoint_path)
|
||||
|
||||
# ### NOTE: HunyuanDiT supports both Chinese and English inputs
|
||||
prompt = "一个宇航员在骑马"
|
||||
# prompt = "An astronaut riding a horse"
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
image = pipe(
|
||||
height=1024, width=1024, prompt=prompt, generator=generator, num_inference_steps=25, guidance_scale=5.0
|
||||
).images[0]
|
||||
|
||||
image.save("img.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the output converted diffusers pipeline.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_style_cond_and_image_meta_size",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="version <= v1.1: True; version >= v1.2: False",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -1,142 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
|
||||
|
||||
|
||||
def main(args):
|
||||
# checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I
|
||||
all_sd = load_file(args.origin_ckpt_path, device="cpu")
|
||||
converted_state_dict = {}
|
||||
# pad token
|
||||
converted_state_dict["pad_token"] = all_sd["pad_token"]
|
||||
|
||||
# patch embed
|
||||
converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
|
||||
converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]
|
||||
|
||||
# time and caption embed
|
||||
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
|
||||
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
|
||||
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
|
||||
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
|
||||
converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
|
||||
converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
|
||||
converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
|
||||
converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]
|
||||
|
||||
for i in range(24):
|
||||
# adaln
|
||||
converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
|
||||
converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
|
||||
converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]
|
||||
|
||||
# qkv
|
||||
converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]
|
||||
|
||||
# cap
|
||||
converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]
|
||||
|
||||
# output
|
||||
converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]
|
||||
|
||||
# attention
|
||||
# qk norm
|
||||
converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
|
||||
|
||||
converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]
|
||||
|
||||
converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
|
||||
|
||||
converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]
|
||||
|
||||
# attention norm
|
||||
converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
|
||||
converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]
|
||||
|
||||
# feed forward
|
||||
converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
|
||||
converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
|
||||
converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]
|
||||
|
||||
# feed forward norm
|
||||
converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
|
||||
converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]
|
||||
|
||||
# final layer
|
||||
converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"]
|
||||
converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"]
|
||||
|
||||
converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"]
|
||||
converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"]
|
||||
|
||||
# Lumina-Next-SFT 2B
|
||||
transformer = LuminaNextDiT2DModel(
|
||||
sample_size=128,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=2304,
|
||||
num_layers=24,
|
||||
num_attention_heads=32,
|
||||
num_kv_heads=8,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
norm_eps=1e-5,
|
||||
learn_sigma=True,
|
||||
qk_norm=True,
|
||||
cross_attention_dim=2048,
|
||||
scaling_factor=1.0,
|
||||
)
|
||||
transformer.load_state_dict(converted_state_dict, strict=True)
|
||||
|
||||
num_model_params = sum(p.numel() for p in transformer.parameters())
|
||||
print(f"Total number of transformer parameters: {num_model_params}")
|
||||
|
||||
if args.only_transformer:
|
||||
transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
|
||||
|
||||
pipeline = LuminaText2ImgPipeline(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
|
||||
)
|
||||
pipeline.save_pretrained(args.dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_size",
|
||||
default=1024,
|
||||
type=int,
|
||||
choices=[256, 512, 1024],
|
||||
required=False,
|
||||
help="Image size of pretrained model, either 512 or 1024.",
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--only_transformer", default=True, type=bool, required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -76,7 +76,6 @@ else:
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AsymmetricAutoencoderKL",
|
||||
"AuraFlowTransformer2DModel",
|
||||
"AutoencoderKL",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderTiny",
|
||||
@@ -89,8 +88,6 @@ else:
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"I2VGenXLUNet",
|
||||
"Kandinsky3UNet",
|
||||
"LatteTransformer3DModel",
|
||||
"LuminaNextDiT2DModel",
|
||||
"ModelMixin",
|
||||
"MotionAdapter",
|
||||
"MultiAdapter",
|
||||
@@ -165,7 +162,6 @@ else:
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"FlowMatchEulerDiscreteScheduler",
|
||||
"FlowMatchHeunDiscreteScheduler",
|
||||
"HeunDiscreteScheduler",
|
||||
"IPNDMScheduler",
|
||||
"KarrasVeScheduler",
|
||||
@@ -236,11 +232,8 @@ else:
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
"AudioLDMPipeline",
|
||||
"AuraFlowPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"ChatGLMModel",
|
||||
"ChatGLMTokenizer",
|
||||
"CLIPImageProjection",
|
||||
"CycleDiffusionPipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
@@ -272,15 +265,11 @@ else:
|
||||
"KandinskyV22Pipeline",
|
||||
"KandinskyV22PriorEmb2EmbPipeline",
|
||||
"KandinskyV22PriorPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
"KolorsPipeline",
|
||||
"LatentConsistencyModelImg2ImgPipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
"LattePipeline",
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LuminaText2ImgPipeline",
|
||||
"MarigoldDepthPipeline",
|
||||
"MarigoldNormalsPipeline",
|
||||
"MusicLDMPipeline",
|
||||
@@ -296,13 +285,11 @@ else:
|
||||
"StableCascadePriorPipeline",
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetXSPipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
@@ -510,7 +497,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .models import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AuraFlowTransformer2DModel,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderTiny,
|
||||
@@ -523,8 +509,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
LatteTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
ModelMixin,
|
||||
MotionAdapter,
|
||||
MultiAdapter,
|
||||
@@ -596,7 +580,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FlowMatchHeunDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
@@ -650,9 +633,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
AuraFlowPipeline,
|
||||
ChatGLMModel,
|
||||
ChatGLMTokenizer,
|
||||
CLIPImageProjection,
|
||||
CycleDiffusionPipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
@@ -684,15 +664,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
KolorsImg2ImgPipeline,
|
||||
KolorsPipeline,
|
||||
LatentConsistencyModelImg2ImgPipeline,
|
||||
LatentConsistencyModelPipeline,
|
||||
LattePipeline,
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LuminaText2ImgPipeline,
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldNormalsPipeline,
|
||||
MusicLDMPipeline,
|
||||
@@ -708,13 +684,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableCascadePriorPipeline,
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
|
||||
@@ -555,4 +555,7 @@ class FromSingleFileMixin:
|
||||
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -22,7 +22,6 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
@@ -71,9 +70,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"MotionAdapter": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -74,9 +74,6 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
||||
"stable_cascade_stage_c": "clip_txt_mapper.weight",
|
||||
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
|
||||
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
|
||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -106,10 +103,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"sd3": {
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
},
|
||||
"animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
|
||||
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
|
||||
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
|
||||
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -492,19 +485,6 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
|
||||
model_type = "sd3"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
|
||||
model_type = "animatediff_v2"
|
||||
|
||||
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
|
||||
model_type = "animatediff_sdxl_beta"
|
||||
|
||||
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
|
||||
model_type = "animatediff_v1"
|
||||
|
||||
else:
|
||||
model_type = "animatediff_v3"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -1828,36 +1808,4 @@ def create_diffusers_t5_model_from_checkpoint(
|
||||
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = model._keep_in_fp32_modules
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if keep_in_fp32_modules is not None:
|
||||
for name, param in model.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||
# param = param.to(torch.float32) does not work here as only in the local scope.
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
for k, v in checkpoint.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -38,12 +38,9 @@ if is_torch_available():
|
||||
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
|
||||
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
|
||||
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
|
||||
_import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
|
||||
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
|
||||
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
|
||||
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
@@ -85,12 +82,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformers import (
|
||||
AuraFlowTransformer2DModel,
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
LatteTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
SD3Transformer2DModel,
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU
|
||||
from .activations import GEGLU, GELU, ApproximateGELU
|
||||
from .attention_processor import Attention, JointAttnProcessor2_0
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
||||
@@ -128,9 +128,9 @@ class JointTransformerBlock(nn.Module):
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
dim_head=attention_head_dim // num_attention_heads,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
out_dim=attention_head_dim,
|
||||
context_pre_only=context_pre_only,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
@@ -359,10 +359,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
if norm_type == "ada_norm_single": # For Latte
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
@@ -442,6 +439,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
else:
|
||||
raise ValueError("Incorrect norm used")
|
||||
|
||||
@@ -458,7 +456,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
@@ -530,56 +527,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LuminaFeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
Parameters:
|
||||
hidden_size (`int`):
|
||||
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
||||
hidden representations.
|
||||
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
|
||||
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
|
||||
of this value.
|
||||
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
|
||||
dimension. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: int,
|
||||
multiple_of: Optional[int] = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(2 * inner_dim / 3)
|
||||
# custom hidden_size factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
||||
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.linear_1 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_2 = nn.Linear(
|
||||
inner_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_3 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.silu = FP32SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class TemporalBasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
|
||||
@@ -22,7 +22,7 @@ from torch import nn
|
||||
from ..image_processor import IPAdapterMaskProcessor
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
|
||||
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -94,7 +94,6 @@ class Attention(nn.Module):
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
kv_heads: Optional[int] = None,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
@@ -104,7 +103,6 @@ class Attention(nn.Module):
|
||||
cross_attention_norm_num_groups: int = 32,
|
||||
qk_norm: Optional[str] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
spatial_norm_dim: Optional[int] = None,
|
||||
out_bias: bool = True,
|
||||
@@ -119,12 +117,7 @@ class Attention(nn.Module):
|
||||
context_pre_only=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# To prevent circular import.
|
||||
from .normalization import FP32LayerNorm
|
||||
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
@@ -175,13 +168,6 @@ class Attention(nn.Module):
|
||||
elif qk_norm == "layer_norm":
|
||||
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
|
||||
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
|
||||
elif qk_norm == "fp32_layer_norm":
|
||||
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
elif qk_norm == "layer_norm_across_heads":
|
||||
# Lumina applys qk norm across all heads
|
||||
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
|
||||
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
|
||||
else:
|
||||
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
|
||||
|
||||
@@ -212,17 +198,17 @@ class Attention(nn.Module):
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
else:
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
@@ -231,14 +217,6 @@ class Attention(nn.Module):
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
|
||||
|
||||
if qk_norm is not None and added_kv_proj_dim is not None:
|
||||
if qk_norm == "fp32_layer_norm":
|
||||
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
else:
|
||||
self.norm_added_q = None
|
||||
self.norm_added_k = None
|
||||
|
||||
# set attention processor
|
||||
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
||||
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
||||
@@ -1128,7 +1106,9 @@ class FusedJointAttnProcessor2_0:
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
@@ -1153,100 +1133,6 @@ class FusedJointAttnProcessor2_0:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class AuraFlowAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing Aura Flow."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
|
||||
raise ImportError(
|
||||
"AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
i=0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
# `context` projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
# Reshape.
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
# Apply QK norm.
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Concatenate the projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
|
||||
|
||||
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# Attention.
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
if encoder_hidden_states is not None:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class XFormersAttnAddedKVProcessor:
|
||||
r"""
|
||||
Processor for implementing memory efficient attention using xFormers.
|
||||
@@ -1708,102 +1594,6 @@ class HunyuanAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LuminaAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
query_rotary_emb: Optional[torch.Tensor] = None,
|
||||
key_rotary_emb: Optional[torch.Tensor] = None,
|
||||
base_sequence_length: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
# Get Query-Key-Value Pair
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query_dim = query.shape[-1]
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = query_dim // attn.heads
|
||||
dtype = query.dtype
|
||||
|
||||
# Get key-value heads
|
||||
kv_heads = inner_dim // head_dim
|
||||
|
||||
# Apply Query-Key Norm if needed
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
key = key.view(batch_size, -1, kv_heads, head_dim)
|
||||
value = value.view(batch_size, -1, kv_heads, head_dim)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if query_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
|
||||
if key_rotary_emb is not None:
|
||||
key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
|
||||
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# Apply proportional attention if true
|
||||
if key_rotary_emb is None:
|
||||
softmax_scale = None
|
||||
else:
|
||||
if base_sequence_length is not None:
|
||||
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
||||
else:
|
||||
softmax_scale = attn.scale
|
||||
|
||||
# perform Grouped-qurey Attention (GQA)
|
||||
n_rep = attn.heads // kv_heads
|
||||
if n_rep >= 1:
|
||||
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
|
||||
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, scale=softmax_scale
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).to(dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
|
||||
@@ -2985,26 +2775,6 @@ class PAGCFGIdentitySelfAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LoRAAttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class LoRAAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class LoRAXFormersAttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class LoRAAttnAddedKVProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
ADDED_KV_ATTENTION_PROCESSORS = (
|
||||
AttnAddedKVProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
|
||||
@@ -360,8 +360,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
if self.config.use_quant_conv:
|
||||
tile = self.quant_conv(tile)
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
@@ -410,8 +409,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
row = []
|
||||
for j in range(0, z.shape[3], overlap_size):
|
||||
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
if self.config.use_post_quant_conv:
|
||||
tile = self.post_quant_conv(tile)
|
||||
tile = self.post_quant_conv(tile)
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
@@ -57,7 +57,6 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
pooled_projection_dim: int = 1024,
|
||||
text_len: int = 77,
|
||||
text_len_t5: int = 256,
|
||||
use_style_cond_and_image_meta_size: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_attention_heads
|
||||
@@ -88,7 +87,6 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
seq_len=text_len_t5,
|
||||
cross_attention_dim=cross_attention_dim_t5,
|
||||
use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
|
||||
@@ -81,7 +81,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
JointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
attention_head_dim=self.inner_dim,
|
||||
context_pre_only=False,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
@@ -239,16 +239,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
|
||||
def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers or config.num_layers
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
|
||||
|
||||
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
|
||||
|
||||
@@ -308,6 +308,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
@@ -35,21 +35,10 @@ def get_timestep_embedding(
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
Args
|
||||
timesteps (torch.Tensor):
|
||||
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
||||
embedding_dim (int):
|
||||
the dimension of the output.
|
||||
flip_sin_to_cos (bool):
|
||||
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
||||
downscale_freq_shift (float):
|
||||
Controls the delta between frequencies between dimensions
|
||||
scale (float):
|
||||
Scaling factor applied to the embeddings.
|
||||
max_period (int):
|
||||
Controls the maximum frequency of the embeddings
|
||||
Returns
|
||||
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
||||
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
@@ -241,52 +230,6 @@ class PatchEmbed(nn.Module):
|
||||
return (latent + pos_embed).to(latent.dtype)
|
||||
|
||||
|
||||
class LuminaPatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding with support for Lumina-T2X"""
|
||||
|
||||
def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.proj = nn.Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
out_features=embed_dim,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x, freqs_cis):
|
||||
"""
|
||||
Patchifies and embeds the input tensor(s).
|
||||
|
||||
Args:
|
||||
x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
|
||||
and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
|
||||
frequency tensor(s).
|
||||
"""
|
||||
freqs_cis = freqs_cis.to(x[0].device)
|
||||
patch_height = patch_width = self.patch_size
|
||||
batch_size, channel, height, width = x.size()
|
||||
height_tokens, width_tokens = height // patch_height, width // patch_width
|
||||
|
||||
x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
|
||||
0, 2, 4, 1, 3, 5
|
||||
)
|
||||
x = x.flatten(3)
|
||||
x = self.proj(x)
|
||||
x = x.flatten(1, 2)
|
||||
|
||||
mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
|
||||
|
||||
return (
|
||||
x,
|
||||
mask,
|
||||
[(height, width)] * batch_size,
|
||||
freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
|
||||
)
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
||||
"""
|
||||
RoPE for image tokens with 2d structure.
|
||||
@@ -331,25 +274,7 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
||||
return emb
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
|
||||
assert embed_dim % 4 == 0
|
||||
|
||||
emb_h = get_1d_rotary_pos_embed(
|
||||
embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
|
||||
) # (H, D/4)
|
||||
emb_w = get_1d_rotary_pos_embed(
|
||||
embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
|
||||
) # (W, D/4)
|
||||
emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
|
||||
emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
|
||||
|
||||
emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0
|
||||
):
|
||||
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||
|
||||
@@ -364,17 +289,13 @@ def get_1d_rotary_pos_embed(
|
||||
Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (`bool`, *optional*):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
linear_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the context extrapolation. Defaults to 1.0.
|
||||
ntk_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = np.arange(pos)
|
||||
theta = theta * ntk_factor
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
||||
if use_real:
|
||||
@@ -389,7 +310,6 @@ def get_1d_rotary_pos_embed(
|
||||
def apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
use_real: bool = True,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
@@ -405,23 +325,16 @@ def apply_rotary_emb(
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
else:
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.type_as(x)
|
||||
return out
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
@@ -473,12 +386,11 @@ class TimestepEmbedding(nn.Module):
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
@@ -486,7 +398,6 @@ class Timesteps(nn.Module):
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
@@ -504,10 +415,9 @@ class GaussianFourierProjection(nn.Module):
|
||||
|
||||
if set_W_to_weight:
|
||||
# to delete later
|
||||
del self.weight
|
||||
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
||||
|
||||
self.weight = self.W
|
||||
del self.W
|
||||
|
||||
def forward(self, x):
|
||||
if self.log:
|
||||
@@ -820,8 +730,6 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
|
||||
self.pooler = HunyuanDiTAttentionPool(
|
||||
seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
|
||||
)
|
||||
@@ -850,7 +758,7 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
|
||||
|
||||
if self.use_style_cond_and_image_meta_size:
|
||||
# extra condition2: image meta size embdding
|
||||
image_meta_size = self.size_proj(image_meta_size.view(-1))
|
||||
image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
|
||||
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
|
||||
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
|
||||
|
||||
@@ -867,40 +775,6 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
|
||||
return conditioning
|
||||
|
||||
|
||||
class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
|
||||
def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(
|
||||
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
|
||||
)
|
||||
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
|
||||
|
||||
self.caption_embedder = nn.Sequential(
|
||||
nn.LayerNorm(cross_attention_dim),
|
||||
nn.Linear(
|
||||
cross_attention_dim,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, timestep, caption_feat, caption_mask):
|
||||
# timestep embedding:
|
||||
time_freq = self.time_proj(timestep)
|
||||
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
|
||||
|
||||
# caption condition embedding:
|
||||
caption_mask_float = caption_mask.float().unsqueeze(-1)
|
||||
caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
|
||||
caption_feats_pool = caption_feats_pool.to(caption_feat)
|
||||
caption_embed = self.caption_embedder(caption_feats_pool)
|
||||
|
||||
conditioning = time_embed + caption_embed
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class TextTimeEmbedding(nn.Module):
|
||||
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
|
||||
super().__init__()
|
||||
|
||||
@@ -221,7 +221,7 @@ def _fetch_index_file(
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=None,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
@@ -22,10 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ..utils import is_torch_version
|
||||
from .activations import get_activation
|
||||
from .embeddings import (
|
||||
CombinedTimestepLabelEmbeddings,
|
||||
PixArtAlphaCombinedTimestepSizeEmbeddings,
|
||||
)
|
||||
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
@@ -51,18 +48,6 @@ class AdaLayerNorm(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class FP32LayerNorm(nn.LayerNorm):
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
origin_dtype = inputs.dtype
|
||||
return F.layer_norm(
|
||||
inputs.float(),
|
||||
self.normalized_shape,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
).to(origin_dtype)
|
||||
|
||||
|
||||
class AdaLayerNormZero(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
@@ -72,7 +57,7 @@ class AdaLayerNormZero(nn.Module):
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
|
||||
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
|
||||
super().__init__()
|
||||
if num_embeddings is not None:
|
||||
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
||||
@@ -80,15 +65,8 @@ class AdaLayerNormZero(nn.Module):
|
||||
self.emb = None
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
elif norm_type == "fp32_layer_norm":
|
||||
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -106,37 +84,6 @@ class AdaLayerNormZero(nn.Module):
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class LuminaRMSNormZero(nn.Module):
|
||||
"""
|
||||
Norm layer adaptive RMS normalization zero.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(
|
||||
min(embedding_dim, 1024),
|
||||
4 * embedding_dim,
|
||||
bias=True,
|
||||
)
|
||||
self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None])
|
||||
|
||||
return x, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
@@ -241,54 +188,6 @@ class AdaLayerNormContinuous(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class LuminaLayerNormContinuous(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
||||
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
||||
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
||||
# However, this is how it was implemented in the original code, and it's rather likely you should
|
||||
# set `elementwise_affine` to False.
|
||||
elementwise_affine=True,
|
||||
eps=1e-5,
|
||||
bias=True,
|
||||
norm_type="layer_norm",
|
||||
out_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# AdaLN
|
||||
self.silu = nn.SiLU()
|
||||
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
# linear_2
|
||||
if out_dim is not None:
|
||||
self.linear_2 = nn.Linear(
|
||||
embedding_dim,
|
||||
out_dim,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
conditioning_embedding: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
||||
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
||||
scale = emb
|
||||
x = self.norm(x) * (1 + scale)[:, None, :]
|
||||
|
||||
if self.linear_2 is not None:
|
||||
x = self.linear_2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
if is_torch_version(">=", "2.1.0"):
|
||||
LayerNorm = nn.LayerNorm
|
||||
else:
|
||||
|
||||
@@ -2,12 +2,9 @@ from ...utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
|
||||
from .dit_transformer_2d import DiTTransformer2DModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .hunyuan_transformer_2d import HunyuanDiT2DModel
|
||||
from .latte_transformer_3d import LatteTransformer3DModel
|
||||
from .lumina_nextdit2d import LuminaNextDiT2DModel
|
||||
from .pixart_transformer_2d import PixArtTransformer2DModel
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
|
||||
@@ -1,422 +0,0 @@
|
||||
# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_processor import Attention, AuraFlowAttnProcessor2_0
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormZero, FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Taken from the original aura flow inference code.
|
||||
def find_multiple(n: int, k: int) -> int:
|
||||
if n % k == 0:
|
||||
return n
|
||||
return n + k - (n % k)
|
||||
|
||||
|
||||
# Aura Flow patch embed doesn't use convs for projections.
|
||||
# Additionally, it uses learned positional embeddings.
|
||||
class AuraFlowPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
height=224,
|
||||
width=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
pos_embed_max_size=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_patches = (height // patch_size) * (width // patch_size)
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
|
||||
self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.height, self.width = height // patch_size, width // patch_size
|
||||
self.base_size = height // patch_size
|
||||
|
||||
def forward(self, latent):
|
||||
batch_size, num_channels, height, width = latent.size()
|
||||
latent = latent.view(
|
||||
batch_size,
|
||||
num_channels,
|
||||
height // self.patch_size,
|
||||
self.patch_size,
|
||||
width // self.patch_size,
|
||||
self.patch_size,
|
||||
)
|
||||
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
|
||||
latent = self.proj(latent)
|
||||
return latent + self.pos_embed
|
||||
|
||||
|
||||
# Taken from the original Aura flow inference code.
|
||||
# Our feedforward only has GELU but Aura uses SiLU.
|
||||
class AuraFlowFeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim=None) -> None:
|
||||
super().__init__()
|
||||
if hidden_dim is None:
|
||||
hidden_dim = 4 * dim
|
||||
|
||||
final_hidden_dim = int(2 * hidden_dim / 3)
|
||||
final_hidden_dim = find_multiple(final_hidden_dim, 256)
|
||||
|
||||
self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False)
|
||||
self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False)
|
||||
self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.silu(self.linear_1(x)) * self.linear_2(x)
|
||||
x = self.out_projection(x)
|
||||
return x
|
||||
|
||||
|
||||
class AuraFlowPreFinalBlock(nn.Module):
|
||||
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = x * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class AuraFlowSingleTransformerBlock(nn.Module):
|
||||
"""Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT."""
|
||||
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
|
||||
|
||||
processor = AuraFlowAttnProcessor2_0()
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
qk_norm="fp32_layer_norm",
|
||||
out_dim=dim,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
|
||||
self.ff = AuraFlowFeedForward(dim, dim * 4)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999):
|
||||
residual = hidden_states
|
||||
|
||||
# Norm + Projection.
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
# Attention.
|
||||
attn_output = self.attn(hidden_states=norm_hidden_states, i=i)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
|
||||
hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
ff_output = self.ff(hidden_states)
|
||||
hidden_states = gate_mlp.unsqueeze(1) * ff_output
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class AuraFlowJointTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):
|
||||
|
||||
* QK Norm in the attention blocks
|
||||
* No bias in the attention blocks
|
||||
* Most LayerNorms are in FP32
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
is_last (`bool`): Boolean to determine if this is the last block in the model.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
|
||||
self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
|
||||
|
||||
processor = AuraFlowAttnProcessor2_0()
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=dim,
|
||||
added_proj_bias=False,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
qk_norm="fp32_layer_norm",
|
||||
out_dim=dim,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=processor,
|
||||
context_pre_only=False,
|
||||
)
|
||||
|
||||
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
|
||||
self.ff = AuraFlowFeedForward(dim, dim * 4)
|
||||
self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
|
||||
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0
|
||||
):
|
||||
residual = hidden_states
|
||||
residual_context = encoder_hidden_states
|
||||
|
||||
# Norm + Projection.
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
|
||||
hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output)
|
||||
encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states)
|
||||
encoder_hidden_states = residual_context + encoder_hidden_states
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`): The width of the latent images. This is fixed during training since
|
||||
it is used to learn a number of position embeddings.
|
||||
patch_size (`int`): Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
||||
num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
|
||||
num_single_dit_layers (`int`, *optional*, defaults to 4):
|
||||
The number of layers of Transformer blocks to use. These blocks use concatenated image and text
|
||||
representations.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
|
||||
out_channels (`int`, defaults to 16): Number of output channels.
|
||||
pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 64,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
num_mmdit_layers: int = 4,
|
||||
num_single_dit_layers: int = 32,
|
||||
attention_head_dim: int = 256,
|
||||
num_attention_heads: int = 12,
|
||||
joint_attention_dim: int = 2048,
|
||||
caption_projection_dim: int = 3072,
|
||||
out_channels: int = 4,
|
||||
pos_embed_max_size: int = 1024,
|
||||
):
|
||||
super().__init__()
|
||||
default_out_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
self.pos_embed = AuraFlowPatchEmbed(
|
||||
height=self.config.sample_size,
|
||||
width=self.config.sample_size,
|
||||
patch_size=self.config.patch_size,
|
||||
in_channels=self.config.in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
)
|
||||
|
||||
self.context_embedder = nn.Linear(
|
||||
self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False
|
||||
)
|
||||
self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True)
|
||||
self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
|
||||
|
||||
self.joint_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
AuraFlowJointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_mmdit_layers)
|
||||
]
|
||||
)
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
AuraFlowSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for _ in range(self.config.num_single_dit_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
|
||||
|
||||
# https://arxiv.org/abs/2309.16588
|
||||
# prevents artifacts in the attention maps
|
||||
self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
# Apply patch embedding, timestep embedding, and project the caption embeddings.
|
||||
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||
temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
|
||||
temb = self.time_step_proj(temb)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
encoder_hidden_states = torch.cat(
|
||||
[self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
|
||||
)
|
||||
|
||||
# MMDiT blocks.
|
||||
for index_block, block in enumerate(self.joint_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block
|
||||
)
|
||||
|
||||
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
|
||||
if len(self.single_transformer_blocks) > 0:
|
||||
encoder_seq_len = encoder_hidden_states.size(1)
|
||||
combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
combined_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
combined_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
|
||||
|
||||
hidden_states = combined_hidden_states[:, encoder_seq_len:]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
patch_size = self.config.patch_size
|
||||
out_channels = self.config.out_channels
|
||||
height = height // patch_size
|
||||
width = width // patch_size
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -14,6 +14,7 @@
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
@@ -28,12 +29,20 @@ from ..embeddings import (
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, FP32LayerNorm
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FP32LayerNorm(nn.LayerNorm):
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
origin_dtype = inputs.dtype
|
||||
return F.layer_norm(
|
||||
inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
|
||||
).to(origin_dtype)
|
||||
|
||||
|
||||
class AdaLayerNormShift(nn.Module):
|
||||
r"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
# Copyright 2024 the Latte Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
"""
|
||||
A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code:
|
||||
https://github.com/Vchitect/Latte
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*):
|
||||
The number of channels in the output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
patch_size (`int`, *optional*):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states. During inference, you can denoise for up to but not more steps than
|
||||
`num_embeds_ada_norm`.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to use elementwise affine in normalization layers.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
|
||||
caption_channels (`int`, *optional*):
|
||||
The number of channels in the caption embeddings.
|
||||
video_length (`int`, *optional*):
|
||||
The number of frames in the video-like data.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: int = 64,
|
||||
patch_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
caption_channels: int = None,
|
||||
video_length: int = 16,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Define input layers
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
|
||||
interpolation_scale = self.config.sample_size // 64
|
||||
interpolation_scale = max(interpolation_scale, 1)
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
)
|
||||
|
||||
# 2. Define spatial transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Define temporal transformers blocks
|
||||
self.temporal_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=None,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
|
||||
# 5. Latte other blocks.
|
||||
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
# define temporal positional embedding
|
||||
temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
|
||||
inner_dim, torch.arange(0, video_length).unsqueeze(1)
|
||||
) # 1152 hidden size
|
||||
self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
enable_temporal_attentions: bool = True,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`LatteTransformer3DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states shape `(batch size, channel, num_frame, height, width)`:
|
||||
Input `hidden_states`.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batcheight, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batcheight, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
enable_temporal_attentions:
|
||||
(`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
# Reshape hidden states
|
||||
batch_size, channels, num_frame, height, width = hidden_states.shape
|
||||
# batch_size channels num_frame height width -> (batch_size * num_frame) channels height width
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)
|
||||
|
||||
# Input
|
||||
height, width = (
|
||||
hidden_states.shape[-2] // self.config.patch_size,
|
||||
hidden_states.shape[-1] // self.config.patch_size,
|
||||
)
|
||||
num_patches = height * width
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings
|
||||
|
||||
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# Prepare text embeddings for spatial block
|
||||
# batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
|
||||
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
|
||||
-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
|
||||
)
|
||||
|
||||
# Prepare timesteps for spatial and temporal block
|
||||
timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
|
||||
timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
|
||||
|
||||
# Spatial and temporal transformer blocks
|
||||
for i, (spatial_block, temp_block) in enumerate(
|
||||
zip(self.transformer_blocks, self.temporal_transformer_blocks)
|
||||
):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
spatial_block,
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
None, # cross_attention_kwargs
|
||||
None, # class_labels
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = spatial_block(
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
encoder_hidden_states_spatial,
|
||||
encoder_attention_mask,
|
||||
timestep_spatial,
|
||||
None, # cross_attention_kwargs
|
||||
None, # class_labels
|
||||
)
|
||||
|
||||
if enable_temporal_attentions:
|
||||
# (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
|
||||
).permute(0, 2, 1, 3)
|
||||
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
|
||||
|
||||
if i == 0 and num_frame > 1:
|
||||
hidden_states = hidden_states + self.temp_pos_embed
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
None, # cross_attention_kwargs
|
||||
None, # class_labels
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = temp_block(
|
||||
hidden_states,
|
||||
None, # attention_mask
|
||||
None, # encoder_hidden_states
|
||||
None, # encoder_attention_mask
|
||||
timestep_temp,
|
||||
None, # cross_attention_kwargs
|
||||
None, # class_labels
|
||||
)
|
||||
|
||||
# (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
|
||||
).permute(0, 2, 1, 3)
|
||||
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
|
||||
|
||||
embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
|
||||
)
|
||||
output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute(
|
||||
0, 2, 1, 3, 4
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -1,340 +0,0 @@
|
||||
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention import LuminaFeedForward
|
||||
from ..attention_processor import Attention, LuminaAttnProcessor2_0
|
||||
from ..embeddings import (
|
||||
LuminaCombinedTimestepCaptionEmbedding,
|
||||
LuminaPatchEmbed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LuminaNextDiTBlock(nn.Module):
|
||||
"""
|
||||
A LuminaNextDiTBlock for LuminaNextDiT2DModel.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): Embedding dimension of the input features.
|
||||
num_attention_heads (`int`): Number of attention heads.
|
||||
num_kv_heads (`int`):
|
||||
Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
|
||||
multiple_of (`int`): The number of multiple of ffn layer.
|
||||
ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension.
|
||||
norm_eps (`float`): The eps for norm layer.
|
||||
qk_norm (`bool`): normalization for query and key.
|
||||
cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to True),
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
num_kv_heads: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: float,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
cross_attention_dim: int,
|
||||
norm_elementwise_affine: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_dim = dim // num_attention_heads
|
||||
|
||||
self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
|
||||
|
||||
# Self-attention
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=dim // num_attention_heads,
|
||||
qk_norm="layer_norm_across_heads" if qk_norm else None,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_kv_heads,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=LuminaAttnProcessor2_0(),
|
||||
)
|
||||
self.attn1.to_out = nn.Identity()
|
||||
|
||||
# Cross-attention
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
dim_head=dim // num_attention_heads,
|
||||
qk_norm="layer_norm_across_heads" if qk_norm else None,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_kv_heads,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=LuminaAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
self.feed_forward = LuminaFeedForward(
|
||||
dim=dim,
|
||||
inner_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier,
|
||||
)
|
||||
|
||||
self.norm1 = LuminaRMSNormZero(
|
||||
embedding_dim=dim,
|
||||
norm_eps=norm_eps,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
image_rotary_emb: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Perform a forward pass through the LuminaNextDiTBlock.
|
||||
|
||||
Parameters:
|
||||
hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock.
|
||||
attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask.
|
||||
image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
|
||||
encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder.
|
||||
encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask.
|
||||
temb (`torch.Tensor`): Timestep embedding with text prompt embedding.
|
||||
cross_attention_kwargs (`Dict[str, Any]`): kwargs for cross attention.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
# Self-attention
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
self_attn_output = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
query_rotary_emb=image_rotary_emb,
|
||||
key_rotary_emb=image_rotary_emb,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
# Cross-attention
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
|
||||
cross_attn_output = self.attn2(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=encoder_mask,
|
||||
query_rotary_emb=image_rotary_emb,
|
||||
key_rotary_emb=None,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
|
||||
mixed_attn_output = self_attn_output + cross_attn_output
|
||||
mixed_attn_output = mixed_attn_output.flatten(-2)
|
||||
# linear proj
|
||||
hidden_states = self.attn2.to_out[0](mixed_attn_output)
|
||||
|
||||
hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)
|
||||
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
LuminaNextDiT: Diffusion model with a Transformer backbone.
|
||||
|
||||
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`): The width of the latent images. This is fixed during training since
|
||||
it is used to learn a number of position embeddings.
|
||||
patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2):
|
||||
The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
|
||||
in_channels (`int`, *optional*, defaults to 4):
|
||||
The number of input channels for the model. Typically, this matches the number of channels in the input
|
||||
images.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
||||
hidden representations.
|
||||
num_layers (`int`, *optional*, default to 32):
|
||||
The number of layers in the model. This defines the depth of the neural network.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
The number of attention heads in each attention layer. This parameter specifies how many separate attention
|
||||
mechanisms are used.
|
||||
num_kv_heads (`int`, *optional*, defaults to 8):
|
||||
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
|
||||
If None, it defaults to num_attention_heads.
|
||||
multiple_of (`int`, *optional*, defaults to 256):
|
||||
A factor that the hidden size should be a multiple of. This can help optimize certain hardware
|
||||
configurations.
|
||||
ffn_dim_multiplier (`float`, *optional*):
|
||||
A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
|
||||
the model configuration.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
A small value added to the denominator for numerical stability in normalization layers.
|
||||
learn_sigma (`bool`, *optional*, defaults to True):
|
||||
Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in
|
||||
predictions.
|
||||
qk_norm (`bool`, *optional*, defaults to True):
|
||||
Indicates if the queries and keys in the attention mechanism should be normalized.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 2048):
|
||||
The dimensionality of the text embeddings. This parameter defines the size of the text representations used
|
||||
in the model.
|
||||
scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||
A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
|
||||
overall scale of the model's operations.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 128,
|
||||
patch_size: Optional[int] = 2,
|
||||
in_channels: Optional[int] = 4,
|
||||
hidden_size: Optional[int] = 2304,
|
||||
num_layers: Optional[int] = 32,
|
||||
num_attention_heads: Optional[int] = 32,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
multiple_of: Optional[int] = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: Optional[float] = 1e-5,
|
||||
learn_sigma: Optional[bool] = True,
|
||||
qk_norm: Optional[bool] = True,
|
||||
cross_attention_dim: Optional[int] = 2048,
|
||||
scaling_factor: Optional[float] = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.sample_size = sample_size
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.scaling_factor = scaling_factor
|
||||
|
||||
self.patch_embedder = LuminaPatchEmbed(
|
||||
patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
|
||||
)
|
||||
|
||||
self.pad_token = nn.Parameter(torch.empty(hidden_size))
|
||||
|
||||
self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
|
||||
hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
LuminaNextDiTBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
cross_attention_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = LuminaLayerNormContinuous(
|
||||
embedding_dim=hidden_size,
|
||||
conditioning_embedding_dim=min(hidden_size, 1024),
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
out_dim=patch_size * patch_size * self.out_channels,
|
||||
)
|
||||
# self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
|
||||
|
||||
assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_mask: torch.Tensor,
|
||||
image_rotary_emb: torch.Tensor,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
return_dict=True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of LuminaNextDiT.
|
||||
|
||||
Parameters:
|
||||
hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
|
||||
timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
|
||||
encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
|
||||
encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
|
||||
"""
|
||||
hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
|
||||
image_rotary_emb = image_rotary_emb.to(hidden_states.device)
|
||||
|
||||
temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
|
||||
|
||||
encoder_mask = encoder_mask.bool()
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
mask,
|
||||
image_rotary_emb,
|
||||
encoder_hidden_states,
|
||||
encoder_mask,
|
||||
temb=temb,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
# unpatchify
|
||||
height_tokens = width_tokens = self.patch_size
|
||||
height, width = img_size[0]
|
||||
batch_size = hidden_states.size(0)
|
||||
sequence_length = (height // height_tokens) * (width // width_tokens)
|
||||
hidden_states = hidden_states[:, :sequence_length].view(
|
||||
batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
|
||||
)
|
||||
output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -19,7 +19,6 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -187,64 +186,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -97,7 +97,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
JointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
attention_head_dim=self.inner_dim,
|
||||
context_pre_only=i == num_layers - 1,
|
||||
)
|
||||
for i in range(self.config.num_layers)
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import logging
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -93,7 +93,7 @@ class MotionModules(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class MotionAdapter(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -478,7 +478,9 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False)
|
||||
x = x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), use_reentrant=False
|
||||
)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
|
||||
@@ -142,7 +142,6 @@ else:
|
||||
_import_structure["pag"].extend(
|
||||
[
|
||||
"StableDiffusionPAGPipeline",
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionXLPAGPipeline",
|
||||
"StableDiffusionXLPAGInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPAGPipeline",
|
||||
@@ -199,12 +198,6 @@ else:
|
||||
"Kandinsky3Img2ImgPipeline",
|
||||
"Kandinsky3Pipeline",
|
||||
]
|
||||
_import_structure["kolors"] = [
|
||||
"KolorsPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
"ChatGLMModel",
|
||||
"ChatGLMTokenizer",
|
||||
]
|
||||
_import_structure["latent_consistency_models"] = [
|
||||
"LatentConsistencyModelImg2ImgPipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
@@ -216,8 +209,6 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
]
|
||||
)
|
||||
_import_structure["latte"] = ["LattePipeline"]
|
||||
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
[
|
||||
"MarigoldDepthPipeline",
|
||||
@@ -251,12 +242,7 @@ else:
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["aura_flow"] = ["AuraFlowPipeline"]
|
||||
_import_structure["stable_diffusion_3"] = [
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
]
|
||||
_import_structure["stable_diffusion_3"] = ["StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline"]
|
||||
_import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
|
||||
_import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
@@ -420,7 +406,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
)
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .controlnet import (
|
||||
BlipDiffusionControlNetPipeline,
|
||||
@@ -490,32 +475,23 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky3Img2ImgPipeline,
|
||||
Kandinsky3Pipeline,
|
||||
)
|
||||
from .kolors import (
|
||||
ChatGLMModel,
|
||||
ChatGLMTokenizer,
|
||||
KolorsImg2ImgPipeline,
|
||||
KolorsPipeline,
|
||||
)
|
||||
from .latent_consistency_models import (
|
||||
LatentConsistencyModelImg2ImgPipeline,
|
||||
LatentConsistencyModelPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .latte import LattePipeline
|
||||
from .ledits_pp import (
|
||||
LEditsPPDiffusionPipelineOutput,
|
||||
LEditsPPInversionPipelineOutput,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .lumina import LuminaText2ImgPipeline
|
||||
from .marigold import (
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldNormalsPipeline,
|
||||
)
|
||||
from .musicldm import MusicLDMPipeline
|
||||
from .pag import (
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
StableDiffusionXLPAGImg2ImgPipeline,
|
||||
@@ -545,11 +521,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_3 import (
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
)
|
||||
from .stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
|
||||
from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
|
||||
from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_aura_flow import AuraFlowPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,591 +0,0 @@
|
||||
# Copyright 2024 AuraFlow Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import T5Tokenizer, UMT5EncoderModel
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import AuraFlowPipeline
|
||||
|
||||
>>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
>>> image.save("aura_flow.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class AuraFlowPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Args:
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. AuraFlow uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
transformer ([`AuraFlowTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
vae: AutoencoderKL,
|
||||
transformer: AuraFlowTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
max_length = max_sequence_length
|
||||
if prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
||||
text_input_ids = text_inputs["input_ids"]
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because T5 can only handle sequences up to"
|
||||
f" {max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(**text_inputs)[0]
|
||||
prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
|
||||
prompt_embeds = prompt_embeds * prompt_attention_mask
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
elif self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.reshape(bs_embed, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_input = {k: v.to(device) for k, v in uncond_input.items()}
|
||||
negative_prompt_embeds = self.text_encoder(**uncond_input)[0]
|
||||
negative_prompt_attention_mask = (
|
||||
uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape)
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.reshape(bs_embed, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
else:
|
||||
negative_prompt_embeds = None
|
||||
negative_prompt_attention_mask = None
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
FusedAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
guidance_scale: float = 3.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 512 by default.
|
||||
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 512 by default.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
|
||||
where the first element is a list with the generated images.
|
||||
"""
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# 2. Determine batch size.
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
|
||||
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
# aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0])
|
||||
timestep = timestep.to(latents.device, dtype=latents.dtype)
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -47,7 +47,6 @@ from .kandinsky2_2 import (
|
||||
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
|
||||
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
|
||||
from .pag import (
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
StableDiffusionXLPAGImg2ImgPipeline,
|
||||
@@ -91,7 +90,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("pixart-alpha", PixArtAlphaPipeline),
|
||||
("pixart-sigma", PixArtSigmaPipeline),
|
||||
("stable-diffusion-pag", StableDiffusionPAGPipeline),
|
||||
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
|
||||
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
|
||||
]
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_kolors"] = ["KolorsPipeline"]
|
||||
_import_structure["pipeline_kolors_img2img"] = ["KolorsImg2ImgPipeline"]
|
||||
_import_structure["text_encoder"] = ["ChatGLMModel"]
|
||||
_import_structure["tokenizer"] = ["ChatGLMTokenizer"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
|
||||
else:
|
||||
from .pipeline_kolors import KolorsPipeline
|
||||
from .pipeline_kolors_img2img import KolorsImg2ImgPipeline
|
||||
from .text_encoder import ChatGLMModel
|
||||
from .tokenizer import ChatGLMTokenizer
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,936 +0,0 @@
|
||||
# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import StableDiffusionXLLoraLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import KolorsPipelineOutput
|
||||
from .text_encoder import ChatGLMModel
|
||||
from .tokenizer import ChatGLMTokenizer
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import KolorsPipeline
|
||||
|
||||
>>> pipe = KolorsPipeline.from_pretrained(
|
||||
... "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = (
|
||||
... "A photo of a ladybug, macro, zoom, high quality, film, holding a wooden sign with the text 'KOLORS'"
|
||||
... )
|
||||
>>> image = pipe(prompt).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Kolors.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`ChatGLMModel`]):
|
||||
Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b).
|
||||
tokenizer (`ChatGLMTokenizer`):
|
||||
Tokenizer of class
|
||||
[ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`):
|
||||
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
|
||||
`Kwai-Kolors/Kolors-diffusers`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
"add_text_embeds",
|
||||
"add_time_ids",
|
||||
"negative_pooled_prompt_embeds",
|
||||
"negative_add_time_ids",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: ChatGLMModel,
|
||||
tokenizer: ChatGLMTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
force_zeros_for_empty_prompt: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
||||
"""
|
||||
# from IPython import embed; embed(); exit()
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizers = [self.tokenizer]
|
||||
text_encoders = [self.text_encoder]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds_list = []
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
output = text_encoder(
|
||||
input_ids=text_inputs["input_ids"],
|
||||
attention_mask=text_inputs["attention_mask"],
|
||||
position_ids=text_inputs["position_ids"],
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]
|
||||
# clone to have a contiguous tensor
|
||||
prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
|
||||
# [max_sequence_length, batch, hidden_size] -> [batch, hidden_size]
|
||||
pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = prompt_embeds_list[0]
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
uncond_input = tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
output = text_encoder(
|
||||
input_ids=uncond_input["input_ids"],
|
||||
attention_mask=uncond_input["attention_mask"],
|
||||
position_ids=uncond_input["position_ids"],
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]
|
||||
# clone to have a contiguous tensor
|
||||
negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
|
||||
# [max_sequence_length, batch, hidden_size] -> [batch, hidden_size]
|
||||
negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds_list[0]
|
||||
|
||||
bs_embed = pooled_prompt_embeds.shape[0]
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 256:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}")
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||
)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
FusedAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(
|
||||
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
|
||||
Args:
|
||||
w (`torch.Tensor`):
|
||||
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
||||
embedding_dim (`int`, *optional*, defaults to 512):
|
||||
Dimension of the embeddings to generate.
|
||||
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
||||
Data type of the generated embeddings.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
||||
"""
|
||||
assert len(w.shape) == 1
|
||||
w = w * 1000.0
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
||||
emb = w.to(dtype)[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1))
|
||||
assert emb.shape == (w.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def denoising_end(self):
|
||||
return self._denoising_end
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
Anything below 512 pixels won't work well for
|
||||
[Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints
|
||||
that are not specifically fine-tuned on low resolutions.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
Anything below 512 pixels won't work well for
|
||||
[Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints
|
||||
that are not specifically fine-tuned on low resolutions.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
denoising_end (`float`, *optional*):
|
||||
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
||||
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
||||
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
||||
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
||||
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
||||
explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
||||
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
||||
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a target image resolution. It should be as same
|
||||
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 8.1 Apply denoising_end
|
||||
if (
|
||||
self.denoising_end is not None
|
||||
and isinstance(self.denoising_end, float)
|
||||
and self.denoising_end > 0
|
||||
and self.denoising_end < 1
|
||||
):
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
# 9. Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
if self.unet.config.time_cond_proj_dim is not None:
|
||||
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
||||
timestep_cond = self.get_guidance_scale_embedding(
|
||||
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
||||
).to(device=device, dtype=latents.dtype)
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
||||
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
elif latents.dtype != self.vae.dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
self.vae = self.vae.to(latents.dtype)
|
||||
|
||||
# unscale/denormalize the latents
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return KolorsPipelineOutput(images=image)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,21 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class KolorsPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Kolors pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -1,889 +0,0 @@
|
||||
# Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from torch.nn.utils import skip_init
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ChatGLMConfig(PretrainedConfig):
|
||||
model_type = "chatglm"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers=28,
|
||||
padded_vocab_size=65024,
|
||||
hidden_size=4096,
|
||||
ffn_hidden_size=13696,
|
||||
kv_channels=128,
|
||||
num_attention_heads=32,
|
||||
seq_length=2048,
|
||||
hidden_dropout=0.0,
|
||||
classifier_dropout=None,
|
||||
attention_dropout=0.0,
|
||||
layernorm_epsilon=1e-5,
|
||||
rmsnorm=True,
|
||||
apply_residual_connection_post_layernorm=False,
|
||||
post_layer_norm=True,
|
||||
add_bias_linear=False,
|
||||
add_qkv_bias=False,
|
||||
bias_dropout_fusion=True,
|
||||
multi_query_attention=False,
|
||||
multi_query_group_num=1,
|
||||
apply_query_key_layer_scaling=True,
|
||||
attention_softmax_in_fp32=True,
|
||||
fp32_residual_connection=False,
|
||||
quantization_bit=0,
|
||||
pre_seq_len=None,
|
||||
prefix_projection=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.vocab_size = padded_vocab_size
|
||||
self.padded_vocab_size = padded_vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.ffn_hidden_size = ffn_hidden_size
|
||||
self.kv_channels = kv_channels
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.seq_length = seq_length
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layernorm_epsilon = layernorm_epsilon
|
||||
self.rmsnorm = rmsnorm
|
||||
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
||||
self.post_layer_norm = post_layer_norm
|
||||
self.add_bias_linear = add_bias_linear
|
||||
self.add_qkv_bias = add_qkv_bias
|
||||
self.bias_dropout_fusion = bias_dropout_fusion
|
||||
self.multi_query_attention = multi_query_attention
|
||||
self.multi_query_group_num = multi_query_group_num
|
||||
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
||||
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
|
||||
self.fp32_residual_connection = fp32_residual_connection
|
||||
self.quantization_bit = quantization_bit
|
||||
self.pre_seq_len = pre_seq_len
|
||||
self.prefix_projection = prefix_projection
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
|
||||
return (self.weight * hidden_states).to(input_dtype)
|
||||
|
||||
|
||||
def _config_to_kwargs(args):
|
||||
common_kwargs = {
|
||||
"dtype": args.torch_dtype,
|
||||
}
|
||||
return common_kwargs
|
||||
|
||||
|
||||
class CoreAttention(torch.nn.Module):
|
||||
def __init__(self, config: ChatGLMConfig, layer_number):
|
||||
super(CoreAttention, self).__init__()
|
||||
|
||||
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
|
||||
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
||||
if self.apply_query_key_layer_scaling:
|
||||
self.attention_softmax_in_fp32 = True
|
||||
self.layer_number = max(1, layer_number)
|
||||
|
||||
projection_size = config.kv_channels * config.num_attention_heads
|
||||
|
||||
# Per attention head and per partition values.
|
||||
self.hidden_size_per_partition = projection_size
|
||||
self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
|
||||
self.num_attention_heads_per_partition = config.num_attention_heads
|
||||
|
||||
coeff = None
|
||||
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
||||
if self.apply_query_key_layer_scaling:
|
||||
coeff = self.layer_number
|
||||
self.norm_factor *= coeff
|
||||
self.coeff = coeff
|
||||
|
||||
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
||||
|
||||
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
||||
pytorch_major_version = int(torch.__version__.split(".")[0])
|
||||
if pytorch_major_version >= 2:
|
||||
query_layer, key_layer, value_layer = [
|
||||
k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
|
||||
]
|
||||
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer, key_layer, value_layer, is_causal=True
|
||||
)
|
||||
else:
|
||||
if attention_mask is not None:
|
||||
attention_mask = ~attention_mask
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer, key_layer, value_layer, attention_mask
|
||||
)
|
||||
context_layer = context_layer.permute(2, 0, 1, 3)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.reshape(*new_context_layer_shape)
|
||||
else:
|
||||
# Raw attention scores
|
||||
|
||||
# [b, np, sq, sk]
|
||||
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
|
||||
|
||||
# [sq, b, np, hn] -> [sq, b * np, hn]
|
||||
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
||||
# [sk, b, np, hn] -> [sk, b * np, hn]
|
||||
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
||||
|
||||
# preallocting input tensor: [b * np, sq, sk]
|
||||
matmul_input_buffer = torch.empty(
|
||||
output_size[0] * output_size[1],
|
||||
output_size[2],
|
||||
output_size[3],
|
||||
dtype=query_layer.dtype,
|
||||
device=query_layer.device,
|
||||
)
|
||||
|
||||
# Raw attention scores. [b * np, sq, sk]
|
||||
matmul_result = torch.baddbmm(
|
||||
matmul_input_buffer,
|
||||
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
||||
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
|
||||
beta=0.0,
|
||||
alpha=(1.0 / self.norm_factor),
|
||||
)
|
||||
|
||||
# change view to [b, np, sq, sk]
|
||||
attention_scores = matmul_result.view(*output_size)
|
||||
|
||||
# ===========================
|
||||
# Attention probs and dropout
|
||||
# ===========================
|
||||
|
||||
# attention scores and attention mask [b, np, sq, sk]
|
||||
if self.attention_softmax_in_fp32:
|
||||
attention_scores = attention_scores.float()
|
||||
if self.coeff is not None:
|
||||
attention_scores = attention_scores * self.coeff
|
||||
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
||||
attention_mask = torch.ones(
|
||||
output_size[0], 1, output_size[2], output_size[3], device=attention_scores.device, dtype=torch.bool
|
||||
)
|
||||
attention_mask.tril_()
|
||||
attention_mask = ~attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
||||
attention_probs = F.softmax(attention_scores, dim=-1)
|
||||
attention_probs = attention_probs.type_as(value_layer)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
# =========================
|
||||
# Context layer. [sq, b, hp]
|
||||
# =========================
|
||||
|
||||
# value_layer -> context layer.
|
||||
# [sk, b, np, hn] --> [b, np, sq, hn]
|
||||
|
||||
# context layer shape: [b, np, sq, hn]
|
||||
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
||||
# change view [sk, b * np, hn]
|
||||
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
|
||||
# change view [b * np, sq, sk]
|
||||
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
||||
# matmul: [b * np, sq, hn]
|
||||
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
|
||||
# change view [b, np, sq, hn]
|
||||
context_layer = context_layer.view(*output_size)
|
||||
# [b, np, sq, hn] --> [sq, b, np, hn]
|
||||
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
||||
# [sq, b, np, hn] --> [sq, b, hp]
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
return context_layer
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
contiguous_split_chunks: bool = False,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Split a tensor along its last dimension.
|
||||
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
|
||||
Returns:
|
||||
A list of Tensors
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = tensor.size()[last_dim] // num_partitions
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# Note: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
||||
# x: [sq, b, np, hn]
|
||||
sq, _b, np, _hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
||||
rot_dim = rope_cache.shape[-2] * 2
|
||||
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
||||
# truncate to support variable sizes
|
||||
rope_cache = rope_cache[:sq]
|
||||
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
|
||||
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
|
||||
x_out2 = torch.stack(
|
||||
[
|
||||
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
||||
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
||||
],
|
||||
-1,
|
||||
)
|
||||
x_out2 = x_out2.flatten(3)
|
||||
return torch.cat((x_out2, x_pass), dim=-1)
|
||||
|
||||
|
||||
class SelfAttention(torch.nn.Module):
|
||||
"""Parallel self-attention layer abstract class.
|
||||
|
||||
Self-attention layer takes input with size [s, b, h] and returns output of the same size.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.layer_number = max(1, layer_number)
|
||||
|
||||
self.projection_size = config.kv_channels * config.num_attention_heads
|
||||
|
||||
# Per attention head and per partition values.
|
||||
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
||||
self.num_attention_heads_per_partition = config.num_attention_heads
|
||||
|
||||
self.multi_query_attention = config.multi_query_attention
|
||||
self.qkv_hidden_size = 3 * self.projection_size
|
||||
if self.multi_query_attention:
|
||||
self.num_multi_query_groups_per_partition = config.multi_query_group_num
|
||||
self.qkv_hidden_size = (
|
||||
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
|
||||
)
|
||||
self.query_key_value = nn.Linear(
|
||||
config.hidden_size,
|
||||
self.qkv_hidden_size,
|
||||
bias=config.add_bias_linear or config.add_qkv_bias,
|
||||
device=device,
|
||||
**_config_to_kwargs(config),
|
||||
)
|
||||
|
||||
self.core_attention = CoreAttention(config, self.layer_number)
|
||||
|
||||
# Output.
|
||||
self.dense = nn.Linear(
|
||||
self.projection_size,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
device=device,
|
||||
**_config_to_kwargs(config),
|
||||
)
|
||||
|
||||
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
|
||||
if self.multi_query_attention:
|
||||
num_attention_heads = self.num_multi_query_groups_per_partition
|
||||
else:
|
||||
num_attention_heads = self.num_attention_heads_per_partition
|
||||
return torch.empty(
|
||||
inference_max_sequence_len,
|
||||
batch_size,
|
||||
num_attention_heads,
|
||||
self.hidden_size_per_attention_head,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):
|
||||
# hidden_states: [sq, b, h]
|
||||
|
||||
# =================================================
|
||||
# Pre-allocate memory for key-values for inference.
|
||||
# =================================================
|
||||
# =====================
|
||||
# Query, Key, and Value
|
||||
# =====================
|
||||
|
||||
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
||||
mixed_x_layer = self.query_key_value(hidden_states)
|
||||
|
||||
if self.multi_query_attention:
|
||||
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
||||
[
|
||||
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
||||
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
query_layer = query_layer.view(
|
||||
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
||||
)
|
||||
key_layer = key_layer.view(
|
||||
key_layer.size()[:-1]
|
||||
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
||||
)
|
||||
value_layer = value_layer.view(
|
||||
value_layer.size()[:-1]
|
||||
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
||||
)
|
||||
else:
|
||||
new_tensor_shape = mixed_x_layer.size()[:-1] + (
|
||||
self.num_attention_heads_per_partition,
|
||||
3 * self.hidden_size_per_attention_head,
|
||||
)
|
||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||
|
||||
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
||||
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
||||
|
||||
# apply relative positional encoding (rotary embedding)
|
||||
if rotary_pos_emb is not None:
|
||||
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
||||
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
||||
|
||||
# adjust key and value for inference
|
||||
if kv_cache is not None:
|
||||
cache_k, cache_v = kv_cache
|
||||
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
||||
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
||||
if use_cache:
|
||||
kv_cache = (key_layer, value_layer)
|
||||
else:
|
||||
kv_cache = None
|
||||
|
||||
if self.multi_query_attention:
|
||||
key_layer = key_layer.unsqueeze(-2)
|
||||
key_layer = key_layer.expand(
|
||||
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
||||
)
|
||||
key_layer = key_layer.contiguous().view(
|
||||
key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
||||
)
|
||||
value_layer = value_layer.unsqueeze(-2)
|
||||
value_layer = value_layer.expand(
|
||||
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
||||
)
|
||||
value_layer = value_layer.contiguous().view(
|
||||
value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
||||
)
|
||||
|
||||
# ==================================
|
||||
# core attention computation
|
||||
# ==================================
|
||||
|
||||
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
||||
|
||||
# =================
|
||||
# Output. [sq, b, h]
|
||||
# =================
|
||||
|
||||
output = self.dense(context_layer)
|
||||
|
||||
return output, kv_cache
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
"""MLP.
|
||||
|
||||
MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation,
|
||||
and project the state back into h hidden dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ChatGLMConfig, device=None):
|
||||
super(MLP, self).__init__()
|
||||
|
||||
self.add_bias = config.add_bias_linear
|
||||
|
||||
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
||||
self.dense_h_to_4h = nn.Linear(
|
||||
config.hidden_size,
|
||||
config.ffn_hidden_size * 2,
|
||||
bias=self.add_bias,
|
||||
device=device,
|
||||
**_config_to_kwargs(config),
|
||||
)
|
||||
|
||||
def swiglu(x):
|
||||
x = torch.chunk(x, 2, dim=-1)
|
||||
return F.silu(x[0]) * x[1]
|
||||
|
||||
self.activation_func = swiglu
|
||||
|
||||
# Project back to h.
|
||||
self.dense_4h_to_h = nn.Linear(
|
||||
config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# [s, b, 4hp]
|
||||
intermediate_parallel = self.dense_h_to_4h(hidden_states)
|
||||
intermediate_parallel = self.activation_func(intermediate_parallel)
|
||||
# [s, b, h]
|
||||
output = self.dense_4h_to_h(intermediate_parallel)
|
||||
return output
|
||||
|
||||
|
||||
class GLMBlock(torch.nn.Module):
|
||||
"""A single transformer layer.
|
||||
|
||||
Transformer layer takes input with size [s, b, h] and returns an output of the same size.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
||||
super(GLMBlock, self).__init__()
|
||||
self.layer_number = layer_number
|
||||
|
||||
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
||||
|
||||
self.fp32_residual_connection = config.fp32_residual_connection
|
||||
|
||||
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Layernorm on the input data.
|
||||
self.input_layernorm = LayerNormFunc(
|
||||
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
|
||||
)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = SelfAttention(config, layer_number, device=device)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
self.post_attention_layernorm = LayerNormFunc(
|
||||
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
|
||||
)
|
||||
|
||||
# MLP
|
||||
self.mlp = MLP(config, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
rotary_pos_emb,
|
||||
kv_cache=None,
|
||||
use_cache=True,
|
||||
):
|
||||
# hidden_states: [s, b, h]
|
||||
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
# Self attention.
|
||||
attention_output, kv_cache = self.self_attention(
|
||||
layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache
|
||||
)
|
||||
|
||||
# Residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
|
||||
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
|
||||
layernorm_input = residual + layernorm_input
|
||||
|
||||
# Layer norm post the self attention.
|
||||
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
||||
|
||||
# MLP.
|
||||
mlp_output = self.mlp(layernorm_output)
|
||||
|
||||
# Second residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = layernorm_input
|
||||
|
||||
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
|
||||
output = residual + output
|
||||
|
||||
return output, kv_cache
|
||||
|
||||
|
||||
class GLMTransformer(torch.nn.Module):
|
||||
"""Transformer class."""
|
||||
|
||||
def __init__(self, config: ChatGLMConfig, device=None):
|
||||
super(GLMTransformer, self).__init__()
|
||||
|
||||
self.fp32_residual_connection = config.fp32_residual_connection
|
||||
self.post_layer_norm = config.post_layer_norm
|
||||
|
||||
# Number of layers.
|
||||
self.num_layers = config.num_layers
|
||||
|
||||
# Transformer layers.
|
||||
def build_layer(layer_number):
|
||||
return GLMBlock(config, layer_number, device=device)
|
||||
|
||||
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
||||
|
||||
if self.post_layer_norm:
|
||||
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Final layer norm before output.
|
||||
self.final_layernorm = LayerNormFunc(
|
||||
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _get_layer(self, layer_number):
|
||||
return self.layers[layer_number]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
rotary_pos_emb,
|
||||
kv_caches=None,
|
||||
use_cache: Optional[bool] = True,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
):
|
||||
if not kv_caches:
|
||||
kv_caches = [None for _ in range(self.num_layers)]
|
||||
presents = () if use_cache else None
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
all_self_attentions = None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for index in range(self.num_layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer = self._get_layer(index)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||
layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
|
||||
)
|
||||
else:
|
||||
layer_ret = layer(
|
||||
hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache
|
||||
)
|
||||
hidden_states, kv_cache = layer_ret
|
||||
if use_cache:
|
||||
presents = presents + (kv_cache,)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
# Final layer norm.
|
||||
if self.post_layer_norm:
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states, presents, all_hidden_states, all_self_attentions
|
||||
|
||||
|
||||
class ChatGLMPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
is_parallelizable = False
|
||||
supports_gradient_checkpointing = True
|
||||
config_class = ChatGLMConfig
|
||||
base_model_prefix = "transformer"
|
||||
_no_split_modules = ["GLMBlock"]
|
||||
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights."""
|
||||
return
|
||||
|
||||
def get_masks(self, input_ids, past_key_values, padding_mask=None):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
|
||||
full_attention_mask.tril_()
|
||||
past_length = 0
|
||||
if past_key_values:
|
||||
past_length = past_key_values[0][0].shape[0]
|
||||
if past_length:
|
||||
full_attention_mask = torch.cat(
|
||||
(torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1
|
||||
)
|
||||
if padding_mask is not None:
|
||||
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
||||
if not past_length and padding_mask is not None:
|
||||
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
||||
full_attention_mask = (full_attention_mask < 0.5).bool()
|
||||
full_attention_mask.unsqueeze_(1)
|
||||
return full_attention_mask
|
||||
|
||||
def get_position_ids(self, input_ids, device):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
||||
return position_ids
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, GLMTransformer):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
def default_init(cls, *args, **kwargs):
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
|
||||
class Embedding(torch.nn.Module):
|
||||
"""Language model embeddings."""
|
||||
|
||||
def __init__(self, config: ChatGLMConfig, device=None):
|
||||
super(Embedding, self).__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
# Word embeddings (parallel).
|
||||
self.word_embeddings = nn.Embedding(
|
||||
config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
|
||||
)
|
||||
self.fp32_residual_connection = config.fp32_residual_connection
|
||||
|
||||
def forward(self, input_ids):
|
||||
# Embeddings.
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
embeddings = words_embeddings
|
||||
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
|
||||
embeddings = embeddings.transpose(0, 1).contiguous()
|
||||
# If the input flag for fp32 residual connection is set, convert for float.
|
||||
if self.fp32_residual_connection:
|
||||
embeddings = embeddings.float()
|
||||
return embeddings
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, original_impl=False, device=None, dtype=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
self.dim = dim
|
||||
self.original_impl = original_impl
|
||||
|
||||
def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
|
||||
"""Enhanced Transformer with Rotary Position Embedding.
|
||||
|
||||
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
||||
transformers/rope/__init__.py. MIT License:
|
||||
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
||||
"""
|
||||
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
||||
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
|
||||
|
||||
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
||||
seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
|
||||
|
||||
# Calculate the product of position index and $\theta_i$
|
||||
idx_theta = torch.outer(seq_idx, theta).float()
|
||||
|
||||
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
|
||||
|
||||
# this is to mimic the behaviour of complex32, else we will get different results
|
||||
if dtype in (torch.float16, torch.bfloat16, torch.int8):
|
||||
cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
|
||||
return cache
|
||||
|
||||
def forward(self, max_seq_len, offset=0):
|
||||
return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
|
||||
|
||||
|
||||
class PrefixEncoder(torch.nn.Module):
|
||||
"""
|
||||
The torch.nn model to encode the prefix Input shape: (batch-size, prefix-length) Output shape: (batch-size,
|
||||
prefix-length, 2*layers*hidden)
|
||||
"""
|
||||
|
||||
def __init__(self, config: ChatGLMConfig):
|
||||
super().__init__()
|
||||
self.prefix_projection = config.prefix_projection
|
||||
if self.prefix_projection:
|
||||
# Use a two-layer MLP to encode the prefix
|
||||
kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
|
||||
self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
|
||||
self.trans = torch.nn.Sequential(
|
||||
torch.nn.Linear(kv_size, config.hidden_size),
|
||||
torch.nn.Tanh(),
|
||||
torch.nn.Linear(config.hidden_size, kv_size),
|
||||
)
|
||||
else:
|
||||
self.embedding = torch.nn.Embedding(
|
||||
config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2
|
||||
)
|
||||
|
||||
def forward(self, prefix: torch.Tensor):
|
||||
if self.prefix_projection:
|
||||
prefix_tokens = self.embedding(prefix)
|
||||
past_key_values = self.trans(prefix_tokens)
|
||||
else:
|
||||
past_key_values = self.embedding(prefix)
|
||||
return past_key_values
|
||||
|
||||
|
||||
class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
|
||||
super().__init__(config)
|
||||
if empty_init:
|
||||
init_method = skip_init
|
||||
else:
|
||||
init_method = default_init
|
||||
init_kwargs = {}
|
||||
if device is not None:
|
||||
init_kwargs["device"] = device
|
||||
self.embedding = init_method(Embedding, config, **init_kwargs)
|
||||
self.num_layers = config.num_layers
|
||||
self.multi_query_group_num = config.multi_query_group_num
|
||||
self.kv_channels = config.kv_channels
|
||||
|
||||
# Rotary positional embeddings
|
||||
self.seq_length = config.seq_length
|
||||
rotary_dim = (
|
||||
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||
)
|
||||
|
||||
self.rotary_pos_emb = RotaryEmbedding(
|
||||
rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
|
||||
)
|
||||
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
||||
self.output_layer = init_method(
|
||||
nn.Linear,
|
||||
config.hidden_size,
|
||||
config.padded_vocab_size,
|
||||
bias=False,
|
||||
dtype=config.torch_dtype,
|
||||
**init_kwargs,
|
||||
)
|
||||
self.pre_seq_len = config.pre_seq_len
|
||||
self.prefix_projection = config.prefix_projection
|
||||
if self.pre_seq_len is not None:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
||||
self.prefix_encoder = PrefixEncoder(config)
|
||||
self.dropout = torch.nn.Dropout(0.1)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embedding.word_embeddings
|
||||
|
||||
def get_prompt(self, batch_size, device, dtype=torch.half):
|
||||
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
||||
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
||||
past_key_values = past_key_values.view(
|
||||
batch_size, self.pre_seq_len, self.num_layers * 2, self.multi_query_group_num, self.kv_channels
|
||||
)
|
||||
# seq_len, b, nh, hidden_size
|
||||
past_key_values = self.dropout(past_key_values)
|
||||
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
|
||||
return past_key_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
full_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
|
||||
if self.pre_seq_len is not None:
|
||||
if past_key_values is None:
|
||||
past_key_values = self.get_prompt(
|
||||
batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
|
||||
)
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
|
||||
)
|
||||
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
# Rotary positional embeddings
|
||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||
if position_ids is not None:
|
||||
rotary_pos_emb = rotary_pos_emb[position_ids]
|
||||
else:
|
||||
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
||||
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
||||
|
||||
# Run encoder.
|
||||
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
||||
inputs_embeds,
|
||||
full_attention_mask,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
kv_caches=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
@@ -1,322 +0,0 @@
|
||||
# Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
|
||||
class SPTokenizer:
|
||||
def __init__(self, model_path: str):
|
||||
# reload tokenizer
|
||||
assert os.path.isfile(model_path), model_path
|
||||
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
||||
|
||||
# BOS / EOS token IDs
|
||||
self.n_words: int = self.sp_model.vocab_size()
|
||||
self.bos_id: int = self.sp_model.bos_id()
|
||||
self.eos_id: int = self.sp_model.eos_id()
|
||||
self.pad_id: int = self.sp_model.unk_id()
|
||||
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
||||
|
||||
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
||||
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
||||
self.special_tokens = {}
|
||||
self.index_special_tokens = {}
|
||||
for token in special_tokens:
|
||||
self.special_tokens[token] = self.n_words
|
||||
self.index_special_tokens[self.n_words] = token
|
||||
self.n_words += 1
|
||||
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
|
||||
|
||||
def tokenize(self, s: str, encode_special_tokens=False):
|
||||
if encode_special_tokens:
|
||||
last_index = 0
|
||||
t = []
|
||||
for match in re.finditer(self.role_special_token_expression, s):
|
||||
if last_index < match.start():
|
||||
t.extend(self.sp_model.EncodeAsPieces(s[last_index : match.start()]))
|
||||
t.append(s[match.start() : match.end()])
|
||||
last_index = match.end()
|
||||
if last_index < len(s):
|
||||
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
|
||||
return t
|
||||
else:
|
||||
return self.sp_model.EncodeAsPieces(s)
|
||||
|
||||
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
|
||||
assert isinstance(s, str)
|
||||
t = self.sp_model.encode(s)
|
||||
if bos:
|
||||
t = [self.bos_id] + t
|
||||
if eos:
|
||||
t = t + [self.eos_id]
|
||||
return t
|
||||
|
||||
def decode(self, t: List[int]) -> str:
|
||||
text, buffer = "", []
|
||||
for token in t:
|
||||
if token in self.index_special_tokens:
|
||||
if buffer:
|
||||
text += self.sp_model.decode(buffer)
|
||||
buffer = []
|
||||
text += self.index_special_tokens[token]
|
||||
else:
|
||||
buffer.append(token)
|
||||
if buffer:
|
||||
text += self.sp_model.decode(buffer)
|
||||
return text
|
||||
|
||||
def decode_tokens(self, tokens: List[str]) -> str:
|
||||
text = self.sp_model.DecodePieces(tokens)
|
||||
return text
|
||||
|
||||
def convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
if token in self.special_tokens:
|
||||
return self.special_tokens[token]
|
||||
return self.sp_model.PieceToId(token)
|
||||
|
||||
def convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
if index in self.index_special_tokens:
|
||||
return self.index_special_tokens[index]
|
||||
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
|
||||
return ""
|
||||
return self.sp_model.IdToPiece(index)
|
||||
|
||||
|
||||
class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
vocab_files_names = {"vocab_file": "tokenizer.model"}
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
padding_side="left",
|
||||
clean_up_tokenization_spaces=False,
|
||||
encode_special_tokens=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.name = "GLMTokenizer"
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.tokenizer = SPTokenizer(vocab_file)
|
||||
self.special_tokens = {
|
||||
"<bos>": self.tokenizer.bos_id,
|
||||
"<eos>": self.tokenizer.eos_id,
|
||||
"<pad>": self.tokenizer.pad_id,
|
||||
}
|
||||
self.encode_special_tokens = encode_special_tokens
|
||||
super().__init__(
|
||||
padding_side=padding_side,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
encode_special_tokens=encode_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_command(self, token):
|
||||
if token in self.special_tokens:
|
||||
return self.special_tokens[token]
|
||||
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
|
||||
return self.tokenizer.special_tokens[token]
|
||||
|
||||
@property
|
||||
def unk_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
return self.get_command("<pad>")
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
return "</s>"
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
return self.get_command("<eos>")
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.n_words
|
||||
|
||||
def get_vocab(self):
|
||||
"""Returns vocab as a dict"""
|
||||
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.tokenizer.convert_token_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.tokenizer.convert_id_to_token(index)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
return self.tokenizer.decode_tokens(tokens)
|
||||
|
||||
def save_vocabulary(self, save_directory, filename_prefix=None):
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
filename_prefix (`str`, *optional*):
|
||||
An optional prefix to add to the named of the saved files.
|
||||
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
|
||||
else:
|
||||
vocab_file = save_directory
|
||||
|
||||
with open(self.vocab_file, "rb") as fin:
|
||||
proto_str = fin.read()
|
||||
|
||||
with open(vocab_file, "wb") as writer:
|
||||
writer.write(proto_str)
|
||||
|
||||
return (vocab_file,)
|
||||
|
||||
def get_prefix_tokens(self):
|
||||
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
||||
return prefix_tokens
|
||||
|
||||
def build_single_message(self, role, metadata, message):
|
||||
assert role in ["system", "user", "assistant", "observation"], role
|
||||
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
|
||||
message_tokens = self.tokenizer.encode(message)
|
||||
tokens = role_tokens + message_tokens
|
||||
return tokens
|
||||
|
||||
def build_chat_input(self, query, history=None, role="user"):
|
||||
if history is None:
|
||||
history = []
|
||||
input_ids = []
|
||||
for item in history:
|
||||
content = item["content"]
|
||||
if item["role"] == "system" and "tools" in item:
|
||||
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
|
||||
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
|
||||
input_ids.extend(self.build_single_message(role, "", query))
|
||||
input_ids.extend([self.get_command("<|assistant|>")])
|
||||
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A BERT sequence has the following format:
|
||||
|
||||
- single sequence: `[CLS] X [SEP]`
|
||||
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
prefix_tokens = self.get_prefix_tokens()
|
||||
token_ids_0 = prefix_tokens + token_ids_0
|
||||
if token_ids_1 is not None:
|
||||
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
|
||||
return token_ids_0
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
|
||||
Args:
|
||||
encoded_inputs:
|
||||
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
||||
max_length: maximum length of the returned list and optionally padding length (see below).
|
||||
Will truncate by taking into account the special tokens.
|
||||
padding_strategy: PaddingStrategy to use for padding.
|
||||
|
||||
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
||||
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
||||
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
||||
The tokenizer padding sides are defined in self.padding_side:
|
||||
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
||||
`>= 7.5` (Volta).
|
||||
return_attention_mask:
|
||||
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
# Load from model defaults
|
||||
assert self.padding_side == "left"
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
seq_length = len(required_input)
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(required_input)
|
||||
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
# Initialize attention mask if not present.
|
||||
if "attention_mask" not in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [1] * seq_length
|
||||
|
||||
if "position_ids" not in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = list(range(seq_length))
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
|
||||
if "attention_mask" in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
||||
if "position_ids" in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
|
||||
return encoded_inputs
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_latte"] = ["LattePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_latte import LattePipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,881 +0,0 @@
|
||||
# Copyright 2024 the Latte Team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AutoencoderKL, LatteTransformer3DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
BaseOutput,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_bs4_available():
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import LattePipeline
|
||||
>>> from diffusers.utils import export_to_gif
|
||||
|
||||
>>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too.
|
||||
>>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")
|
||||
>>> # Enable memory optimizations.
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> prompt = "A small cactus with a happy face in the Sahara desert."
|
||||
>>> videos = pipe(prompt).frames[0]
|
||||
>>> export_to_gif(videos, "latte.gif")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
@dataclass
|
||||
class LattePipelineOutput(BaseOutput):
|
||||
frames: torch.Tensor
|
||||
|
||||
|
||||
class LattePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using Latte.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Latte uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`LatteTransformer3DModel`]):
|
||||
A text conditioned `LatteTransformer3DModel` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
|
||||
bad_punct_regex = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\\*]{1,}")
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKL,
|
||||
transformer: LatteTransformer3DModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
||||
def mask_text_embeddings(self, emb, mask):
|
||||
if emb.shape[0] == 1:
|
||||
keep_index = mask.sum().item()
|
||||
return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096
|
||||
else:
|
||||
masked_feature = emb * mask[:, None, :, None] # 1 120 4096
|
||||
return masked_feature, emb.shape[2]
|
||||
|
||||
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str = "",
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
mask_feature: bool = True,
|
||||
dtype=None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
||||
Latte, this should be "".
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of video that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For Latte, it's should be the embeddings of the "" string.
|
||||
clean_caption (bool, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
mask_feature: (bool, defaults to `True`):
|
||||
If `True`, the function will mask the text embeddings.
|
||||
"""
|
||||
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
max_length = 120
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
prompt_embeds_attention_mask = attention_mask
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
elif self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
|
||||
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
else:
|
||||
negative_prompt_embeds = None
|
||||
|
||||
# Perform additional masking.
|
||||
if mask_feature and not embeds_initially_provided:
|
||||
prompt_embeds = prompt_embeds.unsqueeze(1)
|
||||
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
|
||||
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
|
||||
masked_negative_prompt_embeds = (
|
||||
negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
|
||||
)
|
||||
|
||||
return masked_prompt_embeds, masked_negative_prompt_embeds
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: str = "",
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: int = 1,
|
||||
video_length: int = 16,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
clean_caption: bool = True,
|
||||
mask_feature: bool = True,
|
||||
enable_temporal_attentions: bool = True,
|
||||
decode_chunk_size: Optional[int] = None,
|
||||
) -> Union[LattePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the video generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower video quality.
|
||||
video_length (`int`, *optional*, defaults to 16):
|
||||
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The height in pixels of the generated video.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The width in pixels of the generated video.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For Latte this negative prompt should be "". If not provided,
|
||||
negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate video. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A callback function or a list of callback functions to be called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
||||
A list of tensor inputs that should be passed to the callback function. If not defined, all tensor
|
||||
inputs will be passed.
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
|
||||
enable_temporal_attentions (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions
|
||||
decode_chunk_size (`int`, *optional*):
|
||||
The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
|
||||
expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
|
||||
For lower memory usage, reduce `decode_chunk_size`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.latte.pipeline_latte.LattePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images
|
||||
"""
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 0. Default
|
||||
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else video_length
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
clean_caption=clean_caption,
|
||||
mask_feature=mask_feature,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
video_length,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
current_timestep = t
|
||||
if not torch.is_tensor(current_timestep):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = latent_model_input.device.type == "mps"
|
||||
if isinstance(current_timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
|
||||
elif len(current_timestep.shape) == 0:
|
||||
current_timestep = current_timestep[None].to(latent_model_input.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=current_timestep,
|
||||
enable_temporal_attentions=enable_temporal_attentions,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# use learned sigma?
|
||||
if not (
|
||||
hasattr(self.scheduler.config, "variance_type")
|
||||
and self.scheduler.config.variance_type in ["learned", "learned_range"]
|
||||
):
|
||||
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
||||
|
||||
# compute previous video: x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latents":
|
||||
video = self.decode_latents(latents, video_length, decode_chunk_size=14)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LattePipelineOutput(frames=video)
|
||||
|
||||
# Similar to diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.decode_latents
|
||||
def decode_latents(self, latents: torch.Tensor, video_length: int, decode_chunk_size: int = 14):
|
||||
# [batch, channels, frames, height, width] -> [batch*frames, channels, height, width]
|
||||
latents = latents.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
|
||||
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
|
||||
|
||||
# decode decode_chunk_size frames at a time to avoid OOM
|
||||
frames = []
|
||||
for i in range(0, latents.shape[0], decode_chunk_size):
|
||||
num_frames_in = latents[i : i + decode_chunk_size].shape[0]
|
||||
decode_kwargs = {}
|
||||
if accepts_num_frames:
|
||||
# we only pass num_frames_in if it's expected
|
||||
decode_kwargs["num_frames"] = num_frames_in
|
||||
|
||||
frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
|
||||
frames.append(frame)
|
||||
frames = torch.cat(frames, dim=0)
|
||||
|
||||
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
|
||||
frames = frames.reshape(-1, video_length, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
frames = frames.float()
|
||||
return frames
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_lumina"] = ["LuminaText2ImgPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_lumina import LuminaText2ImgPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,897 +0,0 @@
|
||||
# Copyright 2024 Alpha-VLLM and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import math
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL
|
||||
from ...models.embeddings import get_2d_rotary_pos_embed_lumina
|
||||
from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_bs4_available():
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import LuminaText2ImgPipeline
|
||||
|
||||
>>> pipe = LuminaText2ImgPipeline.from_pretrained(
|
||||
... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
|
||||
... ).cuda()
|
||||
>>> # Enable memory optimizations.
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class LuminaText2ImgPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Lumina-T2I.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`AutoModel`]):
|
||||
Frozen text-encoder. Lumina-T2I uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`AutoModel`):
|
||||
Tokenizer of class
|
||||
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
|
||||
transformer ([`Transformer2DModel`]):
|
||||
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
bad_punct_regex = re.compile(
|
||||
r"["
|
||||
+ "#®•©™&@·º½¾¿¡§~"
|
||||
+ r"\)"
|
||||
+ r"\("
|
||||
+ r"\]"
|
||||
+ r"\["
|
||||
+ r"\}"
|
||||
+ r"\{"
|
||||
+ r"\|"
|
||||
+ "\\"
|
||||
+ r"\/"
|
||||
+ r"\*"
|
||||
+ r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: LuminaNextDiT2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.max_sequence_length = 256
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
self.default_image_size = self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
clean_caption: Optional[bool] = False,
|
||||
max_length: Optional[int] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
pad_to_multiple_of=8,
|
||||
max_length=self.max_sequence_length,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because Gemma can only handle sequences up to"
|
||||
f" {self.max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||
)
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
elif self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
clean_caption: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
||||
Lumina-T2I, this should be "".
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string.
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
clean_caption=clean_caption,
|
||||
)
|
||||
|
||||
# Get negative embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else ""
|
||||
|
||||
# Normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
# Padding negative prompt to the same length with prompt
|
||||
prompt_max_length = prompt_embeds.shape[1]
|
||||
negative_text_inputs = self.tokenizer(
|
||||
negative_prompt,
|
||||
padding="max_length",
|
||||
max_length=prompt_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
negative_text_input_ids = negative_text_inputs.input_ids.to(device)
|
||||
negative_prompt_attention_mask = negative_text_inputs.attention_mask.to(device)
|
||||
# Get the negative prompt embeddings
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
negative_text_input_ids,
|
||||
attention_mask=negative_prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
negative_dtype = self.text_encoder.dtype
|
||||
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=negative_dtype, device=device)
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.view(
|
||||
batch_size * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
num_inference_steps: int = 30,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
sigmas: List[float] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
clean_caption: bool = True,
|
||||
max_sequence_length: int = 256,
|
||||
scaling_watershed: Optional[float] = 1.0,
|
||||
proportional_attn: Optional[bool] = True,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_inference_steps (`int`, *optional*, defaults to 30):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The width in pixels of the generated image.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
max_sequence_length (`int` defaults to 120):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images
|
||||
"""
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
cross_attention_kwargs = {}
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if proportional_attn:
|
||||
cross_attention_kwargs["base_sequence_length"] = (self.default_image_size // 16) ** 2
|
||||
|
||||
scaling_factor = math.sqrt(width * height / self.default_image_size**2)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
current_timestep = t
|
||||
if not torch.is_tensor(current_timestep):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = latent_model_input.device.type == "mps"
|
||||
if isinstance(current_timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
current_timestep = torch.tensor(
|
||||
[current_timestep],
|
||||
dtype=dtype,
|
||||
device=latent_model_input.device,
|
||||
)
|
||||
elif len(current_timestep.shape) == 0:
|
||||
current_timestep = current_timestep[None].to(latent_model_input.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
||||
|
||||
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
|
||||
current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps
|
||||
|
||||
# prepare image_rotary_emb for positional encoding
|
||||
# dynamic scaling_factor for different resolution.
|
||||
# NOTE: For `Time-aware` denosing mechanism from Lumina-Next
|
||||
# https://arxiv.org/abs/2406.18583, Sec 2.3
|
||||
# NOTE: We should compute different image_rotary_emb with different timestep.
|
||||
if current_timestep[0] < scaling_watershed:
|
||||
linear_factor = scaling_factor
|
||||
ntk_factor = 1.0
|
||||
else:
|
||||
linear_factor = 1.0
|
||||
ntk_factor = scaling_factor
|
||||
image_rotary_emb = get_2d_rotary_pos_embed_lumina(
|
||||
self.transformer.head_dim,
|
||||
384,
|
||||
384,
|
||||
linear_factor=linear_factor,
|
||||
ntk_factor=ntk_factor,
|
||||
)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=current_timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_mask=prompt_attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
||||
|
||||
# perform guidance scale
|
||||
# NOTE: For exact reproducibility reasons, we apply classifier-free guidance on only
|
||||
# three channels by default. The standard approach to cfg applies it to all channels.
|
||||
# This can be done by uncommenting the following line and commenting-out the line following that.
|
||||
# eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_eps, noise_pred_rest = noise_pred[:, :3], noise_pred[:, 3:]
|
||||
noise_pred_cond_eps, noise_pred_uncond_eps = torch.split(
|
||||
noise_pred_eps, len(noise_pred_eps) // 2, dim=0
|
||||
)
|
||||
noise_pred_half = noise_pred_uncond_eps + guidance_scale * (
|
||||
noise_pred_cond_eps - noise_pred_uncond_eps
|
||||
)
|
||||
noise_pred_eps = torch.cat([noise_pred_half, noise_pred_half], dim=0)
|
||||
|
||||
noise_pred = torch.cat([noise_pred_eps, noise_pred_rest], dim=1)
|
||||
noise_pred, _ = noise_pred.chunk(2, dim=0)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
noise_pred = -noise_pred
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -22,7 +22,6 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
|
||||
@@ -37,7 +36,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
|
||||
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
|
||||
from .pipeline_pag_sd import StableDiffusionPAGPipeline
|
||||
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+1
@@ -661,6 +661,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
noise_guidance_edit_tmp = torch.einsum(
|
||||
"cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
|
||||
)
|
||||
noise_guidance_edit_tmp = noise_guidance_edit_tmp
|
||||
noise_guidance = noise_guidance + noise_guidance_edit_tmp
|
||||
|
||||
self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu()
|
||||
|
||||
@@ -25,7 +25,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_3"] = ["StableDiffusion3Pipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_3_inpaint"] = ["StableDiffusion3InpaintPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -36,7 +35,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline
|
||||
from .pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -57,7 +57,6 @@ else:
|
||||
_import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"]
|
||||
_import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"]
|
||||
_import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"]
|
||||
_import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"]
|
||||
_import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"]
|
||||
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
|
||||
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
|
||||
@@ -154,7 +153,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
|
||||
from .scheduling_euler_discrete import EulerDiscreteScheduler
|
||||
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
||||
from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
|
||||
from .scheduling_heun_discrete import HeunDiscreteScheduler
|
||||
from .scheduling_ipndm import IPNDMScheduler
|
||||
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
|
||||
|
||||
@@ -377,7 +377,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
|
||||
|
||||
@@ -194,7 +194,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
rescale_betas_zero_snr: int = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
|
||||
@@ -202,7 +202,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
rescale_betas_zero_snr: int = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -158,12 +158,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -173,19 +168,17 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
if sigmas is None:
|
||||
self.num_inference_steps = num_inference_steps
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
|
||||
|
||||
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Heun scheduler.
|
||||
|
||||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
shift (`float`, defaults to 1.0):
|
||||
The shift value for the timestep schedule.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
order = 2
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
shift: float = 1.0,
|
||||
):
|
||||
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
||||
|
||||
sigmas = timesteps / num_train_timesteps
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
|
||||
self.timesteps = sigmas * num_train_timesteps
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigma_min = self.sigmas[-1].item()
|
||||
self.sigma_max = self.sigmas[0].item()
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_noise(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Forward process in flow-matching
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
The current timestep in the diffusion chain.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sample = sigma * noise + (1.0 - sigma) * sample
|
||||
|
||||
return sample
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
|
||||
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
|
||||
|
||||
# empty dt and derivative
|
||||
self.prev_derivative = None
|
||||
self.dt = None
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.dt is None
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: torch.FloatTensor,
|
||||
s_churn: float = 0.0,
|
||||
s_tmin: float = 0.0,
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
s_churn (`float`):
|
||||
s_tmin (`float`):
|
||||
s_tmax (`float`):
|
||||
s_noise (`float`, defaults to 1.0):
|
||||
Scaling factor for noise added to the sample.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# Upcast to avoid precision issues when computing prev_sample
|
||||
sample = sample.to(torch.float32)
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
else:
|
||||
# 2nd order / Heun's method
|
||||
sigma = self.sigmas[self.step_index - 1]
|
||||
sigma_next = self.sigmas[self.step_index]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
|
||||
noise = randn_tensor(
|
||||
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
if gamma > 0:
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
if self.state_in_first_order:
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
denoised = sample - model_output * sigma
|
||||
# 2. convert to an ODE derivative for 1st order
|
||||
derivative = (sample - denoised) / sigma_hat
|
||||
# 3. Delta timestep
|
||||
dt = sigma_next - sigma_hat
|
||||
|
||||
# store for 2nd order step
|
||||
self.prev_derivative = derivative
|
||||
self.dt = dt
|
||||
self.sample = sample
|
||||
else:
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
denoised = sample - model_output * sigma_next
|
||||
# 2. 2nd order / Heun's method
|
||||
derivative = (sample - denoised) / sigma_next
|
||||
derivative = 0.5 * (self.prev_derivative + derivative)
|
||||
|
||||
# 3. take prev timestep & sample
|
||||
dt = self.dt
|
||||
sample = self.sample
|
||||
|
||||
# free dt and derivative
|
||||
# Note, this puts the scheduler in "first order mode"
|
||||
self.prev_derivative = None
|
||||
self.dt = None
|
||||
self.sample = None
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -17,21 +17,6 @@ class AsymmetricAutoencoderKL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AuraFlowTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKL(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -212,36 +197,6 @@ class Kandinsky3UNet(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LatteTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LuminaNextDiT2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ModelMixin(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1140,21 +1095,6 @@ class FlowMatchEulerDiscreteScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FlowMatchHeunDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HeunDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -182,51 +182,6 @@ class AudioLDMPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AuraFlowPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ChatGLMModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ChatGLMTokenizer(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CLIPImageProjection(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -692,36 +647,6 @@ class KandinskyV22PriorPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class KolorsImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class KolorsPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -752,21 +677,6 @@ class LatentConsistencyModelPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LattePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LDMTextToImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -812,21 +722,6 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LuminaText2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class MarigoldDepthPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1052,21 +947,6 @@ class StableDiffusion3Img2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusion3InpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusion3Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1142,21 +1022,6 @@ class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionControlNetPAGPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -455,13 +455,10 @@ def _get_checkpoint_shard_files(
|
||||
|
||||
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
|
||||
allow_patterns = original_shard_filenames
|
||||
if subfolder is not None:
|
||||
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
|
||||
|
||||
ignore_patterns = ["*.json", "*.md"]
|
||||
if not local_files_only:
|
||||
# `model_info` call must guarded with the above condition.
|
||||
model_files_info = model_info(pretrained_model_name_or_path, revision=revision)
|
||||
model_files_info = model_info(pretrained_model_name_or_path)
|
||||
for shard_file in original_shard_filenames:
|
||||
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
|
||||
if not shard_file_present:
|
||||
@@ -484,8 +481,6 @@ def _get_checkpoint_shard_files(
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if subfolder is not None:
|
||||
cached_folder = os.path.join(cached_folder, subfolder)
|
||||
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
|
||||
|
||||
@@ -153,6 +153,7 @@ class SD3LoRATests(unittest.TestCase):
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
pipe.transformer.add_adapter(transformer_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
|
||||
@@ -361,10 +361,9 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
|
||||
forward_requires_fresh_args = True
|
||||
|
||||
def inputs_dict(self, seed=None):
|
||||
if seed is None:
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator("cpu").manual_seed(seed)
|
||||
generator = torch.Generator("cpu")
|
||||
if seed is not None:
|
||||
generator.manual_seed(0)
|
||||
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
|
||||
|
||||
return {"sample": image, "generator": generator}
|
||||
|
||||
@@ -885,11 +885,11 @@ class ModelTesterMixin:
|
||||
|
||||
@require_torch_gpu
|
||||
def test_sharded_checkpoints(self):
|
||||
torch.manual_seed(0)
|
||||
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
@@ -905,14 +905,11 @@ class ModelTesterMixin:
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
self.assertTrue(actual_num_shards == expected_num_shards)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
new_model = self.model_class.from_pretrained(tmp_dir)
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
if "generator" in inputs_dict:
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
@require_torch_gpu
|
||||
@@ -943,8 +940,6 @@ class ModelTesterMixin:
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
if "generator" in inputs_dict:
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
new_output = new_model(**inputs_dict)
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
|
||||
@@ -144,6 +144,9 @@ class PriorTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class PriorTransformerIntegrationTests(unittest.TestCase):
|
||||
def get_dummy_seed_input(self, batch_size=1, embedding_dim=768, num_embeddings=77, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
batch_size = batch_size
|
||||
embedding_dim = embedding_dim
|
||||
num_embeddings = num_embeddings
|
||||
|
||||
hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)
|
||||
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import AuraFlowTransformer2DModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = AuraFlowTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
height = width = embedding_dim = 32
|
||||
sequence_length = 256
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 32,
|
||||
"patch_size": 2,
|
||||
"in_channels": 4,
|
||||
"num_mmdit_layers": 1,
|
||||
"num_single_dit_layers": 1,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 4,
|
||||
"caption_projection_dim": 32,
|
||||
"joint_attention_dim": 32,
|
||||
"out_channels": 4,
|
||||
"pos_embed_max_size": 256,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
@@ -1045,18 +1045,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet"
|
||||
)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = AuraFlowPipeline
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"guidance_scale",
|
||||
"negative_prompt",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = AuraFlowTransformer2DModel(
|
||||
sample_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
num_mmdit_layers=1,
|
||||
num_single_dit_layers=1,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=4,
|
||||
caption_projection_dim=32,
|
||||
joint_attention_dim=32,
|
||||
out_channels=4,
|
||||
pos_embed_max_size=256,
|
||||
)
|
||||
|
||||
text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=32,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
"height": None,
|
||||
"width": None,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_aura_flow_prompt_embeds(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
output_with_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = inputs.pop("prompt")
|
||||
|
||||
do_classifier_free_guidance = inputs["guidance_scale"] > 1
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = pipe.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
device=torch_device,
|
||||
)
|
||||
output_with_embeds = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
**inputs,
|
||||
).images[0]
|
||||
|
||||
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
# Attention slicing needs to implemented differently for this because how single DiT and MMDiT
|
||||
# blocks interfere with each other.
|
||||
return
|
||||
@@ -1,152 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ChatGLMModel,
|
||||
ChatGLMTokenizer,
|
||||
EulerDiscreteScheduler,
|
||||
KolorsPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = KolorsPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(2, 4),
|
||||
layers_per_block=2,
|
||||
time_cond_proj_dim=time_cond_proj_dim,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
# specific config below
|
||||
attention_head_dim=(2, 4),
|
||||
use_linear_projection=True,
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=56,
|
||||
cross_attention_dim=8,
|
||||
norm_num_groups=1,
|
||||
)
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
steps_offset=1,
|
||||
beta_schedule="scaled_linear",
|
||||
timestep_spacing="leading",
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
self.assertEqual(image.shape, (1, 64, 64, 3))
|
||||
expected_slice = np.array(
|
||||
[0.26413745, 0.4425478, 0.4102801, 0.42693347, 0.52529025, 0.3867405, 0.47512037, 0.41538602, 0.43855375]
|
||||
)
|
||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||
self.assertLessEqual(max_diff, 1e-3)
|
||||
|
||||
# should skip it but pipe._optional_components = [] so it doesn't
|
||||
def test_save_load_optional_components(self):
|
||||
pass
|
||||
|
||||
# throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter
|
||||
# not sure if it is worth to fix it before integrating it to transformers
|
||||
def test_save_load_float16(self):
|
||||
# TODO (Alvaro) need to fix later
|
||||
pass
|
||||
|
||||
# throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter
|
||||
# not sure if it is worth to fix it before integrating it to transformers
|
||||
def test_save_load_local(self):
|
||||
# TODO (Alvaro) need to fix later
|
||||
pass
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=5e-4)
|
||||
@@ -1,295 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Latte Team and HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
LattePipeline,
|
||||
LatteTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LattePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = LatteTransformer3DModel(
|
||||
sample_size=8,
|
||||
num_layers=1,
|
||||
patch_size=2,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=3,
|
||||
caption_channels=32,
|
||||
in_channels=4,
|
||||
cross_attention_dim=24,
|
||||
out_channels=8,
|
||||
attention_bias=True,
|
||||
activation_fn="gelu-approximate",
|
||||
num_embeds_ada_norm=1000,
|
||||
norm_type="ada_norm_single",
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL()
|
||||
|
||||
scheduler = DDIMScheduler()
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
components = {
|
||||
"transformer": transformer.eval(),
|
||||
"vae": vae.eval(),
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder.eval(),
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"negative_prompt": "low quality",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"video_length": 1,
|
||||
"output_type": "pt",
|
||||
"clean_caption": False,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
|
||||
self.assertEqual(generated_video.shape, (1, 3, 8, 8))
|
||||
expected_video = torch.randn(1, 3, 8, 8)
|
||||
max_diff = np.abs(generated_video - expected_video).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||
|
||||
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Test passing in a subset
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
# Test passing in a everything
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||
is_last = i == (pipe.num_timesteps - 1)
|
||||
if is_last:
|
||||
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
assert output.abs().sum() < 1e10
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
if not hasattr(self.pipeline_class, "_optional_components"):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
prompt = inputs["prompt"]
|
||||
generator = inputs["generator"]
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = pipe.encode_prompt(prompt)
|
||||
|
||||
# inputs with prompt converted to embeddings
|
||||
inputs = {
|
||||
"prompt_embeds": prompt_embeds,
|
||||
"negative_prompt": None,
|
||||
"negative_prompt_embeds": negative_prompt_embeds,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"video_length": 1,
|
||||
"mask_feature": False,
|
||||
"output_type": "pt",
|
||||
"clean_caption": False,
|
||||
}
|
||||
|
||||
# set all optional components to None
|
||||
for optional_component in pipe._optional_components:
|
||||
setattr(pipe, optional_component, None)
|
||||
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
pipe_loaded.to(torch_device)
|
||||
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
for optional_component in pipe._optional_components:
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, optional_component) is None,
|
||||
f"`{optional_component}` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||
self.assertLess(max_diff, 1.0)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class LattePipelineIntegrationTests(unittest.TestCase):
|
||||
prompt = "A painting of a squirrel eating a burger."
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_latte(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
prompt = self.prompt
|
||||
|
||||
videos = pipe(
|
||||
prompt=prompt,
|
||||
height=512,
|
||||
width=512,
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
clean_caption=False,
|
||||
).frames
|
||||
|
||||
video = videos[0]
|
||||
expected_video = torch.randn(1, 512, 512, 3).numpy()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(video.flatten(), expected_video)
|
||||
assert max_diff < 1e-3, f"Max diff is too high. got {video.flatten()}"
|
||||
@@ -1,179 +0,0 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
|
||||
from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = LuminaText2ImgPipeline
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"guidance_scale",
|
||||
"negative_prompt",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = LuminaNextDiT2DModel(
|
||||
sample_size=16,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=24,
|
||||
num_layers=2,
|
||||
num_attention_heads=3,
|
||||
num_kv_heads=1,
|
||||
multiple_of=16,
|
||||
ffn_dim_multiplier=None,
|
||||
norm_eps=1e-5,
|
||||
learn_sigma=True,
|
||||
qk_norm=True,
|
||||
cross_attention_dim=32,
|
||||
scaling_factor=1.0,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = GemmaConfig(
|
||||
head_dim=4,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=2,
|
||||
num_key_value_heads=4,
|
||||
)
|
||||
text_encoder = GemmaForCausalLM(config)
|
||||
|
||||
components = {
|
||||
"transformer": transformer.eval(),
|
||||
"vae": vae.eval(),
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder.eval(),
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_lumina_prompt_embeds(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
output_with_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = inputs.pop("prompt")
|
||||
|
||||
do_classifier_free_guidance = inputs["guidance_scale"] > 1
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = pipe.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
device=torch_device,
|
||||
)
|
||||
output_with_embeds = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
**inputs,
|
||||
).images[0]
|
||||
|
||||
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
|
||||
assert max_diff < 1e-4
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class LuminaText2ImgPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = LuminaText2ImgPipeline
|
||||
repo_id = "Alpha-VLLM/Lumina-Next-SFT-diffusers"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
return {
|
||||
"prompt": "A photo of a cat",
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
"generator": generator,
|
||||
}
|
||||
|
||||
def test_lumina_inference(self):
|
||||
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, :10, :10]
|
||||
expected_slice = np.array(
|
||||
[
|
||||
[0.17773438, 0.18554688, 0.22070312],
|
||||
[0.046875, 0.06640625, 0.10351562],
|
||||
[0.0, 0.0, 0.02148438],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
|
||||
|
||||
assert max_diff < 1e-4
|
||||
@@ -1,248 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
DDIMScheduler,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
IPAdapterTesterMixin,
|
||||
PipelineFromPipeTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionControlNetPAGPipelineFastTests(
|
||||
PipelineTesterMixin,
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineFromPipeTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetPAGPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
# Copied from tests.pipelines.controlnet.test_controlnet_sdxl.StableDiffusionXLControlNetPipelineFastTests.get_dummy_components
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=8,
|
||||
time_cond_proj_dim=time_cond_proj_dim,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
conditioning_embedding_out_channels=(2, 4),
|
||||
cross_attention_dim=8,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[4, 8],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=8,
|
||||
intermediate_size=16,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=2,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"pag_scale": 3.0,
|
||||
"output_type": "np",
|
||||
"image": image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_pag_disable_enable(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
# base pipeline (expect same output when pag is disabled)
|
||||
pipe_sd = StableDiffusionControlNetPipeline(**components)
|
||||
pipe_sd = pipe_sd.to(device)
|
||||
pipe_sd.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
pipe_pag = self.pipeline_class(**components)
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["pag_scale"] = 0.0
|
||||
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag enabled
|
||||
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
|
||||
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
|
||||
|
||||
def test_pag_cfg(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe_pag(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (
|
||||
1,
|
||||
64,
|
||||
64,
|
||||
3,
|
||||
), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
|
||||
expected_slice = np.array(
|
||||
[0.45505235, 0.2785938, 0.16334778, 0.79689944, 0.53095645, 0.40135607, 0.7052706, 0.69065094, 0.41548574]
|
||||
)
|
||||
|
||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
|
||||
|
||||
def test_pag_uncond(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["guidance_scale"] = 0.0
|
||||
image = pipe_pag(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (
|
||||
1,
|
||||
64,
|
||||
64,
|
||||
3,
|
||||
), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
|
||||
expected_slice = np.array(
|
||||
[0.45127502, 0.2797252, 0.15970308, 0.7993157, 0.5414344, 0.40160775, 0.7114598, 0.69803864, 0.4217583]
|
||||
)
|
||||
|
||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
|
||||
@@ -193,7 +193,7 @@ class StableDiffusionXLControlNetPAGPipelineFastTests(
|
||||
del inputs["pag_scale"]
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user