Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 80a7854540 | |||
| 5ac3d644d2 | |||
| b8317da20f | |||
| a5fe2bd4fd | |||
| 153cf0c393 | |||
| 495fddb8ae | |||
| 367fdef96d | |||
| 82fa9df1c7 | |||
| fb229b54bb | |||
| 0a44380a36 | |||
| 2ed59c178d | |||
| 1f3e02f4da | |||
| 38a603939e | |||
| 169bb0df9c | |||
| f731664773 | |||
| 3dde07a647 | |||
| 701cf86654 | |||
| ca715a9771 | |||
| aa8e328328 | |||
| ff5f2ee505 | |||
| 46619ea717 | |||
| c76e1cc17e | |||
| 315e357a18 | |||
| 1f33ca276d | |||
| 41b0c473d2 | |||
| 0e232ac8c0 | |||
| 2557238b4d | |||
| d71fe55895 | |||
| 7ab424a15a | |||
| dd69b41834 | |||
| 406b1062f8 |
@@ -23,7 +23,7 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
image: diffusers/diffusers-pytorch-compile-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
|
||||
@@ -38,16 +38,9 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build Changed Docker Images
|
||||
env:
|
||||
CHANGED_FILES: "${{ steps.file_changes.outputs.all }}"
|
||||
run: |
|
||||
CHANGED_FILES="${{ steps.file_changes.outputs.all }}"
|
||||
for FILE in $CHANGED_FILES; do
|
||||
# skip anything that isn’t still on disk
|
||||
if [[ ! -f "$FILE" ]]; then
|
||||
echo "Skipping removed file $FILE"
|
||||
continue
|
||||
fi
|
||||
|
||||
if [[ "$FILE" == docker/*Dockerfile ]]; then
|
||||
DOCKER_PATH="${FILE%/Dockerfile}"
|
||||
DOCKER_TAG=$(basename "$DOCKER_PATH")
|
||||
@@ -72,7 +65,7 @@ jobs:
|
||||
image-name:
|
||||
- diffusers-pytorch-cpu
|
||||
- diffusers-pytorch-cuda
|
||||
- diffusers-pytorch-cuda
|
||||
- diffusers-pytorch-compile-cuda
|
||||
- diffusers-pytorch-xformers-cuda
|
||||
- diffusers-pytorch-minimum-cuda
|
||||
- diffusers-flax-cpu
|
||||
|
||||
@@ -188,7 +188,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
image: diffusers/diffusers-pytorch-compile-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
|
||||
@@ -262,7 +262,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
image: diffusers/diffusers-pytorch-compile-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
|
||||
@@ -316,7 +316,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
image: diffusers/diffusers-pytorch-compile-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
&& apt-get install -y software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa
|
||||
|
||||
RUN apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3-pip \
|
||||
python3.10-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3.10 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3.10 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
hf_transfer \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
hf_transfer
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -208,7 +208,7 @@
|
||||
- local: optimization/mps
|
||||
title: Metal Performance Shaders (MPS)
|
||||
- local: optimization/habana
|
||||
title: Intel Gaudi
|
||||
title: Habana Gaudi
|
||||
- local: optimization/neuron
|
||||
title: AWS Neuron
|
||||
title: Optimized hardware
|
||||
|
||||
+56
-33
@@ -11,33 +11,6 @@ specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Caching methods
|
||||
|
||||
## Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
|
||||
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
|
||||
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
|
||||
# poorer quality of generated videos.
|
||||
config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## Faster Cache
|
||||
|
||||
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
|
||||
@@ -65,18 +38,68 @@ config = FasterCacheConfig(
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## First Block Cache
|
||||
|
||||
[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos.
|
||||
# Smaller values between 0.02-0.20 are recommended based on the model being used. The default value is 0.05.
|
||||
config = FirstBlockCacheConfig(threshold=0.07)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
|
||||
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
|
||||
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
|
||||
# poorer quality of generated videos.
|
||||
config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
### CacheMixin
|
||||
|
||||
[[autodoc]] CacheMixin
|
||||
|
||||
### PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
|
||||
### FasterCacheConfig
|
||||
|
||||
[[autodoc]] FasterCacheConfig
|
||||
|
||||
[[autodoc]] apply_faster_cache
|
||||
|
||||
### FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] apply_first_block_cache
|
||||
|
||||
### PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
|
||||
@@ -10,22 +10,67 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Intel Gaudi
|
||||
# Habana Gaudi
|
||||
|
||||
The Intel Gaudi AI accelerator family includes [Intel Gaudi 1](https://habana.ai/products/gaudi/), [Intel Gaudi 2](https://habana.ai/products/gaudi2/), and [Intel Gaudi 3](https://habana.ai/products/gaudi3/). Each server is equipped with 8 devices, known as Habana Processing Units (HPUs), providing 128GB of memory on Gaudi 3, 96GB on Gaudi 2, and 32GB on the first-gen Gaudi. For more details on the underlying hardware architecture, check out the [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html) overview.
|
||||
🤗 Diffusers is compatible with Habana Gaudi through 🤗 [Optimum](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion). Follow the [installation](https://docs.habana.ai/en/latest/Installation_Guide/index.html) guide to install the SynapseAI and Gaudi drivers, and then install Optimum Habana:
|
||||
|
||||
Diffusers pipelines can take advantage of HPU acceleration, even if a pipeline hasn't been added to [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index) yet, with the [GPU Migration Toolkit](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Model_Porting/GPU_Migration_Toolkit/GPU_Migration_Toolkit.html).
|
||||
|
||||
Call `.to("hpu")` on your pipeline to move it to a HPU device as shown below for Flux:
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
||||
pipeline.to("hpu")
|
||||
|
||||
image = pipeline("An image of a squirrel in Picasso style").images[0]
|
||||
```bash
|
||||
python -m pip install --upgrade-strategy eager optimum[habana]
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> For Gaudi-optimized diffusion pipeline implementations, we recommend using [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index).
|
||||
To generate images with Stable Diffusion 1 and 2 on Gaudi, you need to instantiate two instances:
|
||||
|
||||
- [`~optimum.habana.diffusers.GaudiStableDiffusionPipeline`], a pipeline for text-to-image generation.
|
||||
- [`~optimum.habana.diffusers.GaudiDDIMScheduler`], a Gaudi-optimized scheduler.
|
||||
|
||||
When you initialize the pipeline, you have to specify `use_habana=True` to deploy it on HPUs and to get the fastest possible generation, you should enable **HPU graphs** with `use_hpu_graphs=True`.
|
||||
|
||||
Finally, specify a [`~optimum.habana.GaudiConfig`] which can be downloaded from the [Habana](https://huggingface.co/Habana) organization on the Hub.
|
||||
|
||||
```python
|
||||
from optimum.habana import GaudiConfig
|
||||
from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline
|
||||
|
||||
model_name = "stabilityai/stable-diffusion-2-base"
|
||||
scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
|
||||
pipeline = GaudiStableDiffusionPipeline.from_pretrained(
|
||||
model_name,
|
||||
scheduler=scheduler,
|
||||
use_habana=True,
|
||||
use_hpu_graphs=True,
|
||||
gaudi_config="Habana/stable-diffusion-2",
|
||||
)
|
||||
```
|
||||
|
||||
Now you can call the pipeline to generate images by batches from one or several prompts:
|
||||
|
||||
```python
|
||||
outputs = pipeline(
|
||||
prompt=[
|
||||
"High quality photo of an astronaut riding a horse in space",
|
||||
"Face of a yellow cat, high resolution, sitting on a park bench",
|
||||
],
|
||||
num_images_per_prompt=10,
|
||||
batch_size=4,
|
||||
)
|
||||
```
|
||||
|
||||
For more information, check out 🤗 Optimum Habana's [documentation](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion) and the [example](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) provided in the official GitHub repository.
|
||||
|
||||
## Benchmark
|
||||
|
||||
We benchmarked Habana's first-generation Gaudi and Gaudi2 with the [Habana/stable-diffusion](https://huggingface.co/Habana/stable-diffusion) and [Habana/stable-diffusion-2](https://huggingface.co/Habana/stable-diffusion-2) Gaudi configurations (mixed precision bf16/fp32) to demonstrate their performance.
|
||||
|
||||
For [Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) on 512x512 images:
|
||||
|
||||
| | Latency (batch size = 1) | Throughput |
|
||||
| ---------------------- |:------------------------:|:---------------------------:|
|
||||
| first-generation Gaudi | 3.80s | 0.308 images/s (batch size = 8) |
|
||||
| Gaudi2 | 1.33s | 1.081 images/s (batch size = 8) |
|
||||
|
||||
For [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) on 768x768 images:
|
||||
|
||||
| | Latency (batch size = 1) | Throughput |
|
||||
| ---------------------- |:------------------------:|:-------------------------------:|
|
||||
| first-generation Gaudi | 10.2s | 0.108 images/s (batch size = 4) |
|
||||
| Gaudi2 | 3.17s | 0.379 images/s (batch size = 8) |
|
||||
|
||||
@@ -13,30 +13,80 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Quantization
|
||||
|
||||
Quantization focuses on representing data with fewer bits while also trying to preserve the precision of the original data. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits.
|
||||
Quantization techniques focus on representing data with less information while also trying to not lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory-usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits.
|
||||
|
||||
Diffusers supports multiple quantization backends to make large diffusion models like [Flux](../api/pipelines/flux) more accessible. This guide shows how to use the [`~quantizers.PipelineQuantizationConfig`] class to quantize a pipeline during its initialization from a pretrained or non-quantized checkpoint.
|
||||
<Tip>
|
||||
|
||||
Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
If you are new to the quantization field, we recommend you to check out these beginner-friendly courses about quantization in collaboration with DeepLearning.AI:
|
||||
|
||||
* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
|
||||
* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)
|
||||
|
||||
</Tip>
|
||||
|
||||
## When to use what?
|
||||
|
||||
Diffusers currently supports the following quantization methods.
|
||||
- [BitsandBytes](./bitsandbytes)
|
||||
- [TorchAO](./torchao)
|
||||
- [GGUF](./gguf)
|
||||
- [Quanto](./quanto.md)
|
||||
|
||||
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
|
||||
## Pipeline-level quantization
|
||||
|
||||
There are two ways you can use [`~quantizers.PipelineQuantizationConfig`] depending on the level of control you want over the quantization specifications of each model in the pipeline.
|
||||
Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply
|
||||
quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can
|
||||
do this with [`~quantizers.PipelineQuantizationConfig`].
|
||||
|
||||
- for more basic and simple use cases, you only need to define the `quant_backend`, `quant_kwargs`, and `components_to_quantize`
|
||||
- for more granular quantization control, provide a `quant_mapping` that provides the quantization specifications for the individual model components
|
||||
|
||||
### Simple quantization
|
||||
|
||||
Initialize [`~quantizers.PipelineQuantizationConfig`] with the following parameters.
|
||||
|
||||
- `quant_backend` specifies which quantization backend to use. Currently supported backends include: `bitsandbytes_4bit`, `bitsandbytes_8bit`, `gguf`, `quanto`, and `torchao`.
|
||||
- `quant_kwargs` contains the specific quantization arguments to use.
|
||||
- `components_to_quantize` specifies which components of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
|
||||
Start by defining a `PipelineQuantizationConfig`:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.quantizers.quantization_config import QuantoConfig
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={
|
||||
"transformer": QuantoConfig(weights_dtype="int8"),
|
||||
"text_encoder_2": BitsAndBytesConfig(
|
||||
load_in_4bit=True, compute_dtype=torch.bfloat16
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference:
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe("photo of a cute dog").images[0]
|
||||
```
|
||||
|
||||
This method allows for more granular control over the quantization specifications of individual
|
||||
model-level components of a pipeline. It also allows for different quantization backends for
|
||||
different components. In the above example, you used a combination of Quanto and BitsandBytes. However,
|
||||
one caveat of this method is that users need to know which components come from `transformers` to be able
|
||||
to import the right quantization config class.
|
||||
|
||||
The other method is simpler in terms of experience but is
|
||||
less-flexible. Start by defining a `PipelineQuantizationConfig` but in a different way:
|
||||
|
||||
```py
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_4bit",
|
||||
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
|
||||
@@ -44,89 +94,35 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
```
|
||||
|
||||
Pass the `pipeline_quant_config` to [`~DiffusionPipeline.from_pretrained`] to quantize the pipeline.
|
||||
This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] similar to the above example.
|
||||
|
||||
In this case, `quant_kwargs` will be used to initialize the quantization specifications
|
||||
of the respective quantization configuration class of `quant_backend`. `components_to_quantize`
|
||||
is used to denote the components that will be quantized. For most pipelines, you would want to
|
||||
keep `transformer` in the list as that is often the most compute and memory intensive.
|
||||
|
||||
The config below will work for most diffusion pipelines that have a `transformer` component present.
|
||||
In most case, you will want to quantize the `transformer` component as that is often the most compute-
|
||||
intensive part of a diffusion pipeline.
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe("photo of a cute dog").images[0]
|
||||
```
|
||||
|
||||
### quant_mapping
|
||||
|
||||
The `quant_mapping` argument provides more flexible options for how to quantize each individual component in a pipeline, like combining different quantization backends.
|
||||
|
||||
Initialize [`~quantizers.PipelineQuantizationConfig`] and pass a `quant_mapping` to it. The `quant_mapping` allows you to specify the quantization options for each component in the pipeline such as the transformer and text encoder.
|
||||
|
||||
The example below uses two quantization backends, [`~quantizers.QuantoConfig`] and [`transformers.BitsAndBytesConfig`], for the transformer and text encoder.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from diffusers.quantizers.quantization_config import QuantoConfig
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={
|
||||
"transformer": QuantoConfig(weights_dtype="int8"),
|
||||
"text_encoder_2": TransformersBitsAndBytesConfig(
|
||||
load_in_4bit=True, compute_dtype=torch.bfloat16
|
||||
),
|
||||
}
|
||||
quant_backend="bitsandbytes_4bit",
|
||||
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
|
||||
components_to_quantize=["transformer"],
|
||||
)
|
||||
```
|
||||
|
||||
There is a separate bitsandbytes backend in [Transformers](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig). You need to import and use [`transformers.BitsAndBytesConfig`] for components that come from Transformers. For example, `text_encoder_2` in [`FluxPipeline`] is a [`~transformers.T5EncoderModel`] from Transformers so you need to use [`transformers.BitsAndBytesConfig`] instead of [`diffusers.BitsAndBytesConfig`].
|
||||
Below is a list of the supported quantization backends available in both `diffusers` and `transformers`:
|
||||
|
||||
> [!TIP]
|
||||
> Use the [simple quantization](#simple-quantization) method above if you don't want to manage these distinct imports or aren't sure where each pipeline component comes from.
|
||||
* `bitsandbytes_4bit`
|
||||
* `bitsandbytes_8bit`
|
||||
* `gguf`
|
||||
* `quanto`
|
||||
* `torchao`
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={
|
||||
"transformer": DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
|
||||
"text_encoder_2": TransformersBitsAndBytesConfig(
|
||||
load_in_4bit=True, compute_dtype=torch.bfloat16
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
Pass the `pipeline_quant_config` to [`~DiffusionPipeline.from_pretrained`] to quantize the pipeline.
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe("photo of a cute dog").images[0]
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
Check out the resources below to learn more about quantization.
|
||||
|
||||
- If you are new to quantization, we recommend checking out the following beginner-friendly courses in collaboration with DeepLearning.AI.
|
||||
|
||||
- [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
|
||||
- [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)
|
||||
|
||||
- Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) if you're interested in adding a new quantization method.
|
||||
|
||||
- The Transformers quantization [Overview](https://huggingface.co/docs/transformers/quantization/overview#when-to-use-what) provides an overview of the pros and cons of different quantization backends.
|
||||
|
||||
- Read the [Exploring Quantization Backends in Diffusers](https://huggingface.co/blog/diffusers-quantization) blog post for a brief introduction to each quantization backend, how to choose a backend, and combining quantization with other memory optimizations.
|
||||
Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's
|
||||
recommended to quantize the text encoders that are memory-intensive. Some examples include T5,
|
||||
Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through
|
||||
`text_encoder_2` while keeping the CLIP model intact (accessible through `text_encoder`).
|
||||
@@ -175,7 +175,7 @@
|
||||
- local: optimization/mps
|
||||
title: Metal Performance Shaders (MPS)
|
||||
- local: optimization/habana
|
||||
title: Intel Gaudi
|
||||
title: Habana Gaudi
|
||||
title: 최적화된 하드웨어
|
||||
title: 추론 가속화와 메모리 줄이기
|
||||
- sections:
|
||||
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Intel Gaudi에서 Stable Diffusion을 사용하는 방법
|
||||
# Habana Gaudi에서 Stable Diffusion을 사용하는 방법
|
||||
|
||||
🤗 Diffusers는 🤗 [Optimum Habana](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)를 통해서 Habana Gaudi와 호환됩니다.
|
||||
|
||||
|
||||
@@ -133,9 +133,11 @@ else:
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"FasterCacheConfig",
|
||||
"FirstBlockCacheConfig",
|
||||
"HookRegistry",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
@@ -740,9 +742,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
from .models import (
|
||||
|
||||
@@ -1,8 +1,23 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
|
||||
{
|
||||
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,221 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
|
||||
_FBC_BLOCK_HOOK = "fbc_block_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirstBlockCacheConfig:
|
||||
r"""
|
||||
Configuration for [First Block
|
||||
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
|
||||
|
||||
Args:
|
||||
threshold (`float`, defaults to `0.05`):
|
||||
The threshold to determine whether or not a forward pass through all layers of the model is required. A
|
||||
higher threshold usually results in a forward pass through a lower number of layers and faster inference,
|
||||
but might lead to poorer generation quality. A lower threshold may not result in significant generation
|
||||
speedup. The threshold is compared against the absmean difference of the residuals between the current and
|
||||
cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
|
||||
is skipped.
|
||||
"""
|
||||
|
||||
threshold: float = 0.05
|
||||
|
||||
|
||||
class FBCSharedBlockState(BaseState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.head_block_output_hidden_states: torch.Tensor = None
|
||||
self.head_block_output_encoder_hidden_states: torch.Tensor = None
|
||||
self.head_block_residual_hidden_states: torch.Tensor = None
|
||||
self.tail_block_residual_hidden_states: torch.Tensor = None
|
||||
self.tail_block_residual_encoder_hidden_states: torch.Tensor = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
def reset(self):
|
||||
self.head_block_output_hidden_states = None
|
||||
self.head_block_output_encoder_hidden_states = None
|
||||
self.head_block_residual_hidden_states = None
|
||||
self.tail_block_residual_hidden_states = None
|
||||
self.tail_block_residual_encoder_hidden_states = None
|
||||
self.should_compute = True
|
||||
|
||||
|
||||
class FBCHeadBlockHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager, threshold: float):
|
||||
self.state_manager = state_manager
|
||||
self.threshold = threshold
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
|
||||
raise ValueError(
|
||||
f"Module {unwrapped_module} does not have any registered metadata. "
|
||||
"Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
|
||||
)
|
||||
self._metadata = unwrapped_module._diffusers_transformer_block_metadata
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
hidden_states_residual = output.hidden_states - original_hidden_states
|
||||
|
||||
shared_state: FBCSharedBlockState = self.state_manager.get_state()
|
||||
hidden_states = encoder_hidden_states = None
|
||||
should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
|
||||
shared_state.should_compute = should_compute
|
||||
|
||||
if not should_compute:
|
||||
# Apply caching
|
||||
return_output = output.__class__()
|
||||
hidden_states = shared_state.tail_block_residual_hidden_states + output.hidden_states
|
||||
return_output = return_output._replace(hidden_states=hidden_states)
|
||||
if hasattr(output, "encoder_hidden_states"):
|
||||
encoder_hidden_states = (
|
||||
shared_state.tail_block_residual_encoder_hidden_states + output.encoder_hidden_states
|
||||
)
|
||||
return_output = return_output._replace(encoder_hidden_states=encoder_hidden_states)
|
||||
else:
|
||||
return_output = output
|
||||
shared_state.head_block_output_hidden_states = output.hidden_states
|
||||
if hasattr(output, "encoder_hidden_states"):
|
||||
shared_state.head_block_output_encoder_hidden_states = output.encoder_hidden_states
|
||||
shared_state.head_block_residual_hidden_states = hidden_states_residual
|
||||
|
||||
return return_output
|
||||
|
||||
def reset_state(self, module):
|
||||
self.state_manager.reset()
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
|
||||
shared_state = self.state_manager.get_state()
|
||||
if shared_state.head_block_residual_hidden_states is None:
|
||||
return True
|
||||
prev_hidden_states_residual = shared_state.head_block_residual_hidden_states
|
||||
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
|
||||
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
|
||||
diff = (absmean / prev_hidden_states_absmean).item()
|
||||
return diff > self.threshold
|
||||
|
||||
|
||||
class FBCBlockHook(ModelHook):
|
||||
def __init__(self, state_manager: StateManager, is_tail: bool = False):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
self.is_tail = is_tail
|
||||
self._metadata = None
|
||||
self._output_cls = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"):
|
||||
raise ValueError(
|
||||
f"Module {unwrapped_module} does not have any registered metadata. "
|
||||
"Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`."
|
||||
)
|
||||
self._metadata = unwrapped_module._diffusers_transformer_block_metadata
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
|
||||
|
||||
original_encoder_hidden_states = None
|
||||
try:
|
||||
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
except ValueError:
|
||||
# This is expected for models that don't have use encoder_hidden_states in their forward definition
|
||||
pass
|
||||
|
||||
shared_state = self.state_manager.get_state()
|
||||
|
||||
if shared_state.should_compute:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
if self._output_cls is None:
|
||||
self._output_cls = output.__class__
|
||||
if self.is_tail:
|
||||
hidden_states_residual = output.hidden_states - shared_state.head_block_output_hidden_states
|
||||
if hasattr(output, "encoder_hidden_states"):
|
||||
encoder_hidden_states_residual = (
|
||||
output.encoder_hidden_states - shared_state.head_block_output_encoder_hidden_states
|
||||
)
|
||||
shared_state.tail_block_residual_hidden_states = hidden_states_residual
|
||||
shared_state.tail_block_residual_encoder_hidden_states = encoder_hidden_states_residual
|
||||
return output
|
||||
|
||||
assert self._output_cls is not None
|
||||
return_output = self._output_cls()
|
||||
return_output = return_output._replace(hidden_states=original_hidden_states)
|
||||
if hasattr(return_output, "encoder_hidden_states"):
|
||||
return_output = return_output._replace(encoder_hidden_states=original_encoder_hidden_states)
|
||||
return return_output
|
||||
|
||||
|
||||
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
|
||||
state_manager = StateManager(FBCSharedBlockState, (), {})
|
||||
remaining_blocks = []
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
||||
continue
|
||||
for index, block in enumerate(submodule):
|
||||
remaining_blocks.append((f"{name}.{index}", block))
|
||||
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
|
||||
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
logger.debug(f"Applying FBCBlockHook to '{name}'")
|
||||
_apply_fbc_block_hook(block, state_manager)
|
||||
|
||||
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
|
||||
_apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
|
||||
|
||||
|
||||
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCHeadBlockHook(state_manager, threshold)
|
||||
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
|
||||
|
||||
|
||||
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCBlockHook(state_manager, is_tail)
|
||||
registry.register_hook(hook, _FBC_BLOCK_HOOK)
|
||||
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BaseState:
|
||||
def reset(self, *args, **kwargs) -> None:
|
||||
raise NotImplementedError(
|
||||
"BaseState::reset is not implemented. Please implement this method in the derived class."
|
||||
)
|
||||
|
||||
|
||||
class StateManager:
|
||||
def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
|
||||
self._state_cls = state_cls
|
||||
self._init_args = init_args if init_args is not None else ()
|
||||
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
|
||||
self._state_cache = {}
|
||||
self._current_context = None
|
||||
|
||||
def get_state(self):
|
||||
if self._current_context is None:
|
||||
raise ValueError("No context is set. Please set a context before retrieving the state.")
|
||||
if self._current_context not in self._state_cache.keys():
|
||||
self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
|
||||
return self._state_cache[self._current_context]
|
||||
|
||||
def set_context(self, name: str) -> None:
|
||||
self._current_context = name
|
||||
|
||||
def reset(self, *args, **kwargs) -> None:
|
||||
for name, state in list(self._state_cache.items()):
|
||||
state.reset(*args, **kwargs)
|
||||
self._state_cache.pop(name)
|
||||
self._current_context = None
|
||||
|
||||
|
||||
class ModelHook:
|
||||
r"""
|
||||
A hook that contains callbacks to be executed just before and after the forward method of a model.
|
||||
@@ -99,6 +132,14 @@ class ModelHook:
|
||||
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
|
||||
return module
|
||||
|
||||
def _set_context(self, module: torch.nn.Module, name: str) -> None:
|
||||
# Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
|
||||
for attr_name in dir(self):
|
||||
attr = getattr(self, attr_name)
|
||||
if isinstance(attr, StateManager):
|
||||
attr.set_context(name)
|
||||
return module
|
||||
|
||||
|
||||
class HookFunctionReference:
|
||||
def __init__(self) -> None:
|
||||
@@ -211,9 +252,10 @@ class HookRegistry:
|
||||
hook.reset_state(self._module_ref)
|
||||
|
||||
if recurse:
|
||||
for module_name, module in self._module_ref.named_modules():
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook.reset_stateful_hooks(recurse=False)
|
||||
|
||||
@@ -223,6 +265,19 @@ class HookRegistry:
|
||||
module._diffusers_hook = cls(module)
|
||||
return module._diffusers_hook
|
||||
|
||||
def _set_context(self, name: Optional[str] = None) -> None:
|
||||
for hook_name in reversed(self._hook_order):
|
||||
hook = self.hooks[hook_name]
|
||||
if hook._is_stateful:
|
||||
hook._set_context(self._module_ref, name)
|
||||
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook._set_context(name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
registry_repr = ""
|
||||
for i, hook_name in enumerate(self._hook_order):
|
||||
|
||||
@@ -2545,13 +2545,14 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
if unexpected_modules:
|
||||
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
|
||||
|
||||
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
|
||||
for k in lora_module_names:
|
||||
if k in unexpected_modules:
|
||||
continue
|
||||
|
||||
base_param_name = (
|
||||
f"{k.replace(prefix, '')}.base_layer.weight"
|
||||
if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
|
||||
if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
|
||||
else f"{k.replace(prefix, '')}.weight"
|
||||
)
|
||||
base_weight_param = transformer_state_dict[base_param_name]
|
||||
|
||||
@@ -22,6 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
|
||||
from .attention_processor import Attention, JointAttnProcessor2_0
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .metadata import TransformerBlockMetadata, register_transformer_block
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
|
||||
|
||||
|
||||
@@ -258,6 +259,12 @@ class JointTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
)
|
||||
)
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
@@ -12,16 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AutoModel(ConfigMixin):
|
||||
@@ -155,50 +152,15 @@ class AutoModel(ConfigMixin):
|
||||
"token": token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
}
|
||||
|
||||
library = None
|
||||
orig_class_name = None
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
orig_class_name = config["_class_name"]
|
||||
|
||||
# Always attempt to fetch model_index.json first
|
||||
try:
|
||||
cls.config_name = "model_index.json"
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
|
||||
if subfolder is not None and subfolder in config:
|
||||
library, orig_class_name = config[subfolder]
|
||||
load_config_kwargs.update({"subfolder": subfolder})
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.debug(e)
|
||||
|
||||
# Unable to load from model_index.json so fallback to loading from config
|
||||
if library is None and orig_class_name is None:
|
||||
cls.config_name = "config.json"
|
||||
config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)
|
||||
|
||||
if "_class_name" in config:
|
||||
# If we find a class name in the config, we can try to load the model as a diffusers model
|
||||
orig_class_name = config["_class_name"]
|
||||
library = "diffusers"
|
||||
load_config_kwargs.update({"subfolder": subfolder})
|
||||
elif "model_type" in config:
|
||||
orig_class_name = "AutoModel"
|
||||
library = "transformers"
|
||||
load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder})
|
||||
else:
|
||||
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
|
||||
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
model_cls, _ = get_class_obj_and_candidates(
|
||||
library_name=library,
|
||||
class_name=orig_class_name,
|
||||
importable_classes=ALL_IMPORTABLE_CLASSES,
|
||||
pipelines=None,
|
||||
is_pipeline_module=False,
|
||||
)
|
||||
library = importlib.import_module("diffusers")
|
||||
|
||||
model_cls = getattr(library, orig_class_name, None)
|
||||
if model_cls is None:
|
||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
|
||||
@@ -25,6 +27,7 @@ class CacheMixin:
|
||||
Supported caching techniques:
|
||||
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
|
||||
- [FasterCache](https://huggingface.co/papers/2410.19355)
|
||||
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
|
||||
"""
|
||||
|
||||
_cache_config = None
|
||||
@@ -62,8 +65,10 @@ class CacheMixin:
|
||||
|
||||
from ..hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
|
||||
@@ -72,31 +77,36 @@ class CacheMixin:
|
||||
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
|
||||
)
|
||||
|
||||
if isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, FasterCacheConfig):
|
||||
if isinstance(config, FasterCacheConfig):
|
||||
apply_faster_cache(self, config)
|
||||
elif isinstance(config, FirstBlockCacheConfig):
|
||||
apply_first_block_cache(self, config)
|
||||
elif isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(config)} is not supported.")
|
||||
|
||||
self._cache_config = config
|
||||
|
||||
def disable_cache(self) -> None:
|
||||
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
|
||||
if self._cache_config is None:
|
||||
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
||||
return
|
||||
|
||||
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, FasterCacheConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
if isinstance(self._cache_config, FasterCacheConfig):
|
||||
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
|
||||
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, FirstBlockCacheConfig):
|
||||
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
|
||||
|
||||
@@ -106,3 +116,15 @@ class CacheMixin:
|
||||
from ..hooks import HookRegistry
|
||||
|
||||
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
|
||||
|
||||
@contextmanager
|
||||
def cache_context(self, name: str):
|
||||
r"""Context manager that provides additional methods for cache management."""
|
||||
from ..hooks import HookRegistry
|
||||
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry._set_context(name)
|
||||
|
||||
yield
|
||||
|
||||
registry._set_context(None)
|
||||
|
||||
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
single_block_samples = ()
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
|
||||
single_block_samples = single_block_samples + (hidden_states,)
|
||||
|
||||
# controlnet block
|
||||
controlnet_block_samples = ()
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Type
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerBlockMetadata:
|
||||
return_hidden_states_index: int = None
|
||||
return_encoder_hidden_states_index: int = None
|
||||
|
||||
_cls: Type = None
|
||||
_cached_parameter_indices: Dict[str, int] = None
|
||||
|
||||
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
if identifier in kwargs:
|
||||
return kwargs[identifier]
|
||||
if self._cached_parameter_indices is not None:
|
||||
return args[self._cached_parameter_indices[identifier]]
|
||||
if self._cls is None:
|
||||
raise ValueError("Model class is not set for metadata.")
|
||||
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
|
||||
parameters = parameters[1:] # skip `self`
|
||||
self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
|
||||
if identifier not in self._cached_parameter_indices:
|
||||
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
|
||||
index = self._cached_parameter_indices[identifier]
|
||||
if index >= len(args):
|
||||
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
|
||||
return args[index]
|
||||
|
||||
|
||||
def register_transformer_block(metadata: TransformerBlockMetadata):
|
||||
def inner(model_class: Type):
|
||||
metadata._cls = model_class
|
||||
model_class._diffusers_transformer_block_metadata = metadata
|
||||
return model_class
|
||||
|
||||
return inner
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -34,6 +34,11 @@ from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CogVideoXBlockOutput(NamedTuple):
|
||||
hidden_states: torch.Tensor = None
|
||||
encoder_hidden_states: torch.Tensor = None
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class CogVideoXBlock(nn.Module):
|
||||
r"""
|
||||
@@ -122,7 +127,7 @@ class CogVideoXBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> CogVideoXBlockOutput:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -154,7 +159,7 @@ class CogVideoXBlock(nn.Module):
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
return CogVideoXBlockOutput(hidden_states, encoder_hidden_states)
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
||||
|
||||
@@ -21,10 +21,12 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
|
||||
from ..metadata import TransformerBlockMetadata, register_transformer_block
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
@@ -453,6 +455,13 @@ class CogView4TrainingAttnProcessor:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
)
|
||||
)
|
||||
class CogView4TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -34,6 +34,7 @@ from ..attention_processor import (
|
||||
)
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..metadata import TransformerBlockMetadata, register_transformer_block
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
@@ -43,6 +44,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
)
|
||||
)
|
||||
class FluxSingleTransformerBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
|
||||
super().__init__()
|
||||
@@ -79,10 +86,14 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
@@ -100,10 +111,17 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
)
|
||||
)
|
||||
class FluxTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
||||
@@ -508,20 +526,21 @@ class FluxTransformer2DModel(
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
@@ -531,12 +550,7 @@ class FluxTransformer2DModel(
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
@@ -33,6 +33,7 @@ from ..embeddings import (
|
||||
Timesteps,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..metadata import TransformerBlockMetadata, register_transformer_block
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
|
||||
@@ -310,6 +311,12 @@ class HunyuanVideoConditionEmbedding(nn.Module):
|
||||
return conditioning, token_replace_emb
|
||||
|
||||
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
)
|
||||
)
|
||||
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -489,6 +496,12 @@ class HunyuanVideoRotaryPosEmbed(nn.Module):
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
)
|
||||
)
|
||||
class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -565,6 +578,12 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
)
|
||||
)
|
||||
class HunyuanVideoTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -644,6 +663,12 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
)
|
||||
)
|
||||
class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -724,6 +749,12 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
)
|
||||
)
|
||||
class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -28,6 +28,7 @@ from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection
|
||||
from ..metadata import TransformerBlockMetadata, register_transformer_block
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
@@ -196,6 +197,12 @@ class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
)
|
||||
)
|
||||
class LTXVideoTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
|
||||
|
||||
@@ -27,6 +27,7 @@ from ..attention import FeedForward
|
||||
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
||||
from ..metadata import TransformerBlockMetadata, register_transformer_block
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
||||
@@ -116,6 +117,12 @@ class MochiRMSNormZero(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@register_transformer_block(
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
)
|
||||
)
|
||||
class MochiTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -22,10 +22,12 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..metadata import TransformerBlockMetadata, register_transformer_block
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
@@ -34,6 +36,11 @@ from ..normalization import FP32LayerNorm
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanTransformerBlockOutput(NamedTuple):
|
||||
hidden_states: torch.Tensor = None
|
||||
encoder_hidden_states: torch.Tensor = None
|
||||
|
||||
|
||||
class WanAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -219,6 +226,8 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
return freqs
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@register_transformer_block(TransformerBlockMetadata())
|
||||
class WanTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -276,7 +285,7 @@ class WanTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
) -> WanTransformerBlockOutput:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
@@ -298,7 +307,7 @@ class WanTransformerBlock(nn.Module):
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
return WanTransformerBlockOutput(hidden_states, encoder_hidden_states)
|
||||
|
||||
|
||||
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
@@ -447,12 +456,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
else:
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
|
||||
# 5. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
@@ -718,14 +718,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
|
||||
@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
|
||||
@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
ofs=ofs_emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
ofs=ofs_emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
|
||||
@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
|
||||
@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred_cond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond = self.transformer(
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred_cond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
@@ -643,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
@@ -909,32 +909,35 @@ class FluxPipeline(
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
if negative_image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
||||
neg_noise_pred = self.transformer(
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
txt_ids=negative_text_ids,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
if negative_image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
||||
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
txt_ids=negative_text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
||||
@@ -693,28 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
neg_noise_pred = self.transformer(
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_attention_mask=negative_prompt_attention_mask,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_attention_mask=negative_prompt_attention_mask,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
guidance=guidance,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
||||
@@ -757,18 +757,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
|
||||
@@ -1177,15 +1177,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
if is_conditioning_image_or_video:
|
||||
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
video_coords=video_coords,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
video_coords=video_coords,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
|
||||
@@ -830,18 +830,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
|
||||
@@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
with self.transformer.cache_context("cond_uncond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# Mochi CFG + Sampling runs in FP32
|
||||
noise_pred = noise_pred.to(torch.float32)
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ for library in LOADABLE_CLASSES:
|
||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||
|
||||
|
||||
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
|
||||
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
|
||||
"""
|
||||
Checking for safetensors compatibility:
|
||||
- The model is safetensors compatible only if there is a safetensors file for each model component present in
|
||||
@@ -103,31 +103,6 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
# .bin, .safetensors, ...
|
||||
weight_suffixs = [w.split(".")[-1] for w in weight_names]
|
||||
# -00001-of-00002
|
||||
transformers_index_format = r"\d{5}-of-\d{5}"
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
|
||||
variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
non_variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
|
||||
passed_components = passed_components or []
|
||||
if folder_names:
|
||||
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
|
||||
@@ -146,29 +121,15 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
components[component].append(component_filename)
|
||||
|
||||
# If there are no component folders check the main directory for safetensors files
|
||||
filtered_filenames = set()
|
||||
if not components:
|
||||
if variant is not None:
|
||||
filtered_filenames = filter_with_regex(filenames, variant_file_re)
|
||||
|
||||
# If no variant filenames exist check if non-variant files are available
|
||||
if not filtered_filenames:
|
||||
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
|
||||
return any(".safetensors" in filename for filename in filtered_filenames)
|
||||
return any(".safetensors" in filename for filename in filenames)
|
||||
|
||||
# iterate over all files of a component
|
||||
# check if safetensor files exist for that component
|
||||
# if variant is provided check if the variant of the safetensors exists
|
||||
for component, component_filenames in components.items():
|
||||
matches = []
|
||||
filtered_component_filenames = set()
|
||||
# if variant is provided check if the variant of the safetensors exists
|
||||
if variant is not None:
|
||||
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
|
||||
|
||||
# if variant safetensor files do not exist check for non-variants
|
||||
if not filtered_component_filenames:
|
||||
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
|
||||
for component_filename in filtered_component_filenames:
|
||||
for component_filename in component_filenames:
|
||||
filename, extension = os.path.splitext(component_filename)
|
||||
|
||||
match_exists = extension == ".safetensors"
|
||||
@@ -198,10 +159,6 @@ def filter_model_files(filenames):
|
||||
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
|
||||
|
||||
|
||||
def filter_with_regex(filenames, pattern_re):
|
||||
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
|
||||
|
||||
|
||||
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
@@ -250,6 +207,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
|
||||
# interested in the extension name
|
||||
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
|
||||
|
||||
def filter_with_regex(filenames, pattern_re):
|
||||
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
|
||||
|
||||
# Group files by component
|
||||
components = {}
|
||||
for filename in filenames:
|
||||
@@ -375,14 +335,14 @@ def get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||
):
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
|
||||
component_folder = os.path.join(cache_dir, component_name)
|
||||
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
|
||||
elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
|
||||
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
|
||||
# load custom component
|
||||
class_obj = get_class_from_dynamic_module(
|
||||
component_folder, module_file=library_name + ".py", class_name=class_name
|
||||
@@ -1037,7 +997,7 @@ def _get_ignore_patterns(
|
||||
use_safetensors
|
||||
and not allow_pickle
|
||||
and not is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
)
|
||||
):
|
||||
raise EnvironmentError(
|
||||
@@ -1048,7 +1008,7 @@ def _get_ignore_patterns(
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
||||
|
||||
elif use_safetensors and is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
|
||||
@@ -1665,8 +1665,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
signature_types[k] = (v.annotation,)
|
||||
elif get_origin(v.annotation) == Union:
|
||||
signature_types[k] = get_args(v.annotation)
|
||||
elif get_origin(v.annotation) in [List, Dict, list, dict]:
|
||||
signature_types[k] = (v.annotation,)
|
||||
else:
|
||||
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
|
||||
return signature_types
|
||||
|
||||
@@ -530,22 +530,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
||||
@@ -17,6 +17,21 @@ class FasterCacheConfig(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FirstBlockCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HookRegistry(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs):
|
||||
requires_backends(apply_faster_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_first_block_cache(*args, **kwargs):
|
||||
requires_backends(apply_first_block_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(*args, **kwargs):
|
||||
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
|
||||
|
||||
|
||||
@@ -635,10 +635,10 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
|
||||
return arry
|
||||
|
||||
|
||||
def load_pt(url: str, map_location: Optional[str] = None, weights_only: Optional[bool] = True):
|
||||
def load_pt(url: str, map_location: str):
|
||||
response = requests.get(url, timeout=DIFFUSERS_REQUEST_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
arry = torch.load(BytesIO(response.content), map_location=map_location, weights_only=weights_only)
|
||||
arry = torch.load(BytesIO(response.content), map_location=map_location)
|
||||
return arry
|
||||
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ except (ImportError, ModuleNotFoundError):
|
||||
def randn_tensor(
|
||||
shape: Union[Tuple, List],
|
||||
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
|
||||
device: Optional[Union[str, "torch.device"]] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
dtype: Optional["torch.dtype"] = None,
|
||||
layout: Optional["torch.layout"] = None,
|
||||
):
|
||||
@@ -47,8 +47,6 @@ def randn_tensor(
|
||||
is always created on the CPU.
|
||||
"""
|
||||
# device on which tensor is created defaults to device
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
rand_device = device
|
||||
batch_size = shape[0]
|
||||
|
||||
@@ -92,6 +90,11 @@ def is_compiled_module(module) -> bool:
|
||||
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
|
||||
|
||||
|
||||
def unwrap_module(module):
|
||||
"""Unwraps a module if it was compiled with torch.compile()"""
|
||||
return module._orig_mod if is_compiled_module(module) else module
|
||||
|
||||
|
||||
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
|
||||
"""Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
|
||||
|
||||
|
||||
@@ -2149,51 +2149,3 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
def test_inference_load_delete_load_adapters(self):
|
||||
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
lora_loadable_components = self.pipeline_class._lora_loadable_modules
|
||||
if "text_encoder_2" in lora_loadable_components:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
|
||||
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# First, delete adapter and compare.
|
||||
pipe.delete_adapters(pipe.get_active_adapters()[0])
|
||||
output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
|
||||
|
||||
# Then load adapter and compare.
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
|
||||
|
||||
@@ -1748,14 +1748,14 @@ class TorchCompileTesterMixin:
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
torch.compiler.reset()
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
torch.compiler.reset()
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -1764,17 +1764,13 @@ class TorchCompileTesterMixin:
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch.compiler.reset()
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
torch.no_grad(),
|
||||
):
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@@ -1802,7 +1798,7 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
|
||||
# there will be recompilation errors, as torch caches the model when run in the same process.
|
||||
super().tearDown()
|
||||
torch.compiler.reset()
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -1919,7 +1915,7 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
def test_hotswapping_compiled_model_linear(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
@@ -1929,7 +1925,7 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["conv", "conv1", "conv2"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
@@ -1939,7 +1935,7 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "conv"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import CLIPTextModel, LongformerModel
|
||||
|
||||
from diffusers.models import AutoModel, UNet2DConditionModel
|
||||
|
||||
|
||||
class TestAutoModel(unittest.TestCase):
|
||||
@patch(
|
||||
"diffusers.models.AutoModel.load_config",
|
||||
side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}],
|
||||
)
|
||||
def test_load_from_config_diffusers_with_subfolder(self, mock_load_config):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
|
||||
assert isinstance(model, UNet2DConditionModel)
|
||||
|
||||
@patch(
|
||||
"diffusers.models.AutoModel.load_config",
|
||||
side_effect=[EnvironmentError("File not found"), {"model_type": "clip_text_model"}],
|
||||
)
|
||||
def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
|
||||
def test_load_from_config_without_subfolder(self):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-longformer")
|
||||
assert isinstance(model, LongformerModel)
|
||||
|
||||
def test_load_from_model_index(self):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
@@ -19,16 +19,20 @@ import torch
|
||||
from diffusers import HunyuanVideoTransformer3DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
is_torch_compile,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -92,8 +96,23 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin,
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -160,8 +179,23 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompi
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -226,10 +260,23 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileT
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
|
||||
ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase
|
||||
):
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -295,3 +342,18 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@@ -19,16 +19,20 @@ import torch
|
||||
from diffusers import WanTransformer3DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
is_torch_compile,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -82,3 +86,18 @@ class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@@ -32,6 +32,7 @@ from diffusers.utils.testing_utils import (
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
@@ -44,7 +45,11 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class CogVideoXPipelineFastTests(
|
||||
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = CogVideoXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -38,9 +39,13 @@ from diffusers.utils.testing_utils import (
|
||||
backend_reset_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
enable_full_determinism,
|
||||
get_python_version,
|
||||
is_torch_compile,
|
||||
load_image,
|
||||
load_numpy,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
run_test_in_subprocess,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -63,6 +68,52 @@ from ..test_pipelines_common import (
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
_ = in_queue.get(timeout=timeout)
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
|
||||
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
pipe.controlnet.to(memory_format=torch.channels_last)
|
||||
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
|
||||
output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
|
||||
)
|
||||
expected_image = np.resize(expected_image, (512, 512, 3))
|
||||
|
||||
assert np.abs(expected_image - image).max() < 1.0
|
||||
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class ControlNetPipelineFastTests(
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
@@ -1002,6 +1053,15 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
expected_slice = np.array([0.1655, 0.1721, 0.1623, 0.1685, 0.1711, 0.1646, 0.1651, 0.1631, 0.1494])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@is_torch_compile
|
||||
@require_torch_2
|
||||
@unittest.skipIf(
|
||||
get_python_version == (3, 12),
|
||||
reason="Torch Dynamo isn't yet supported for Python 3.12.",
|
||||
)
|
||||
def test_stable_diffusion_compile(self):
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
|
||||
|
||||
def test_v11_shuffle_global_pool_conditions(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle")
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -35,9 +36,13 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
is_torch_compile,
|
||||
load_image,
|
||||
load_numpy,
|
||||
require_accelerator,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
run_test_in_subprocess,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -73,6 +78,53 @@ def to_np(tensor):
|
||||
return tensor
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
_ = in_queue.get(timeout=timeout)
|
||||
|
||||
controlnet = ControlNetXSAdapter.from_pretrained(
|
||||
"UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
|
||||
output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
|
||||
)
|
||||
expected_image = np.resize(expected_image, (512, 512, 3))
|
||||
|
||||
assert np.abs(expected_image - image).max() < 1.0
|
||||
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class ControlNetXSPipelineFastTests(
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
@@ -350,3 +402,8 @@ class ControlNetXSPipelineSlowTests(unittest.TestCase):
|
||||
original_image = image[-3:, -3:, -1].flatten()
|
||||
expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
|
||||
assert np.allclose(original_image, expected_image, atol=1e-04)
|
||||
|
||||
@is_torch_compile
|
||||
@require_torch_2
|
||||
def test_stable_diffusion_compile(self):
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
|
||||
|
||||
@@ -25,6 +25,7 @@ from diffusers.utils.testing_utils import (
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
@@ -34,11 +35,12 @@ from ..test_pipelines_common import (
|
||||
|
||||
|
||||
class FluxPipelineFastTests(
|
||||
unittest.TestCase,
|
||||
PipelineTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = FluxPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
|
||||
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
to_np,
|
||||
@@ -43,7 +44,11 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoPipelineFastTests(
|
||||
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = HunyuanVideoPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
|
||||
@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = LTXVideoTransformer3DModel(
|
||||
in_channels=8,
|
||||
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
num_attention_heads=4,
|
||||
attention_head_dim=8,
|
||||
cross_attention_dim=32,
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
caption_channels=32,
|
||||
)
|
||||
|
||||
|
||||
@@ -33,13 +33,15 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
|
||||
class MochiPipelineFastTests(
|
||||
PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = MochiPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
|
||||
@@ -304,8 +304,7 @@ class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_embedding = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt",
|
||||
map_location=torch_device,
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
@@ -321,4 +320,4 @@ class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/stable_cascade_decoder_image.npy"
|
||||
)
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
|
||||
assert max_diff < 2e-4
|
||||
assert max_diff < 1e-4
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -48,12 +49,16 @@ from diffusers.utils.testing_utils import (
|
||||
backend_reset_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
enable_full_determinism,
|
||||
is_torch_compile,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerate_version_greater,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
require_torch_multi_accelerator,
|
||||
run_test_in_subprocess,
|
||||
skip_mps,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -76,6 +81,39 @@ from ..test_pipelines_common import (
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
inputs = in_queue.get(timeout=timeout)
|
||||
torch_device = inputs.pop("torch_device")
|
||||
seed = inputs.pop("seed")
|
||||
inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
|
||||
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
|
||||
sd_pipe.unet.to(memory_format=torch.channels_last)
|
||||
sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
|
||||
|
||||
assert np.abs(image_slice - expected_slice).max() < 5e-3
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class StableDiffusionPipelineFastTests(
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
@@ -1186,6 +1224,40 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 8e-1
|
||||
|
||||
@is_torch_compile
|
||||
@require_torch_2
|
||||
def test_stable_diffusion_compile(self):
|
||||
seed = 0
|
||||
inputs = self.get_inputs(torch_device, seed=seed)
|
||||
# Can't pickle a Generator object
|
||||
del inputs["generator"]
|
||||
inputs["torch_device"] = torch_device
|
||||
inputs["seed"] = seed
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=inputs)
|
||||
|
||||
def test_stable_diffusion_lcm(self):
|
||||
unet = UNet2DConditionModel.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", subfolder="unet")
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", unet=unet).to(torch_device)
|
||||
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 6
|
||||
inputs["output_type"] = "pil"
|
||||
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_full/stable_diffusion_lcm.png"
|
||||
)
|
||||
|
||||
image = sd_pipe.image_processor.pil_to_numpy(image)
|
||||
expected_image = sd_pipe.image_processor.pil_to_numpy(expected_image)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
|
||||
|
||||
assert max_diff < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import gc
|
||||
import random
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -40,10 +41,13 @@ from diffusers.utils.testing_utils import (
|
||||
backend_reset_peak_memory_stats,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
is_torch_compile,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
run_test_in_subprocess,
|
||||
skip_mps,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -66,6 +70,38 @@ from ..test_pipelines_common import (
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_img2img_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
inputs = in_queue.get(timeout=timeout)
|
||||
torch_device = inputs.pop("torch_device")
|
||||
seed = inputs.pop("seed")
|
||||
inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 768, 3)
|
||||
expected_slice = np.array([0.0606, 0.0570, 0.0805, 0.0579, 0.0628, 0.0623, 0.0843, 0.1115, 0.0806])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class StableDiffusionImg2ImgPipelineFastTests(
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
@@ -618,6 +654,17 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
|
||||
assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros
|
||||
|
||||
@is_torch_compile
|
||||
@require_torch_2
|
||||
def test_img2img_compile(self):
|
||||
seed = 0
|
||||
inputs = self.get_inputs(torch_device, seed=seed)
|
||||
# Can't pickle a Generator object
|
||||
del inputs["generator"]
|
||||
inputs["torch_device"] = torch_device
|
||||
inputs["seed"] = seed
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_img2img_compile, inputs=inputs)
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import gc
|
||||
import random
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -43,10 +44,13 @@ from diffusers.utils.testing_utils import (
|
||||
backend_reset_peak_memory_stats,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
is_torch_compile,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
run_test_in_subprocess,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -67,6 +71,40 @@ from ..test_pipelines_common import (
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_inpaint_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
inputs = in_queue.get(timeout=timeout)
|
||||
torch_device = inputs.pop("torch_device")
|
||||
seed = inputs.pop("seed")
|
||||
inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"botp/stable-diffusion-v1-5-inpainting", safety_checker=None
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0689, 0.0699, 0.0790, 0.0536, 0.0470, 0.0488, 0.041, 0.0508, 0.04179])
|
||||
assert np.abs(expected_slice - image_slice).max() < 3e-3
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class StableDiffusionInpaintPipelineFastTests(
|
||||
IPAdapterTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
@@ -689,6 +727,17 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||
# make sure that less than 2.2 GB is allocated
|
||||
assert mem_bytes < 2.2 * 10**9
|
||||
|
||||
@is_torch_compile
|
||||
@require_torch_2
|
||||
def test_inpaint_compile(self):
|
||||
seed = 0
|
||||
inputs = self.get_inputs(torch_device, seed=seed)
|
||||
# Can't pickle a Generator object
|
||||
del inputs["generator"]
|
||||
inputs["torch_device"] = torch_device
|
||||
inputs["seed"] = seed
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_inpaint_compile, inputs=inputs)
|
||||
|
||||
def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"botp/stable-diffusion-v1-5-inpainting", safety_checker=None
|
||||
@@ -915,6 +964,11 @@ class StableDiffusionInpaintPipelineAsymmetricAutoencoderKLSlowTests(unittest.Te
|
||||
# make sure that less than 2.45 GB is allocated
|
||||
assert mem_bytes < 2.45 * 10**9
|
||||
|
||||
@is_torch_compile
|
||||
@require_torch_2
|
||||
def test_inpaint_compile(self):
|
||||
pass
|
||||
|
||||
def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
|
||||
vae = AsymmetricAutoencoderKL.from_pretrained(
|
||||
"cross-attention/asymmetric-autoencoder-kl-x-1-5",
|
||||
|
||||
@@ -20,32 +20,26 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionKDiffusionPipeline
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_diffusion_1(self):
|
||||
sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
|
||||
@@ -28,13 +28,7 @@ from diffusers import (
|
||||
StableDiffusionLDM3DPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
|
||||
@@ -211,17 +205,17 @@ class StableDiffusionLDM3DPipelineFastTests(unittest.TestCase):
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class StableDiffusionLDM3DPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
@@ -262,17 +256,17 @@ class StableDiffusionLDM3DPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class StableDiffusionPipelineNightlyTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
|
||||
@@ -29,13 +29,7 @@ from diffusers import (
|
||||
StableDiffusionSAGPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import (
|
||||
@@ -168,19 +162,19 @@ class StableDiffusionSAGPipelineFastTests(
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_diffusion_1(self):
|
||||
sag_pipe = StableDiffusionSAGPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
|
||||
@@ -13,17 +13,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
enable_full_determinism,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import (
|
||||
@@ -200,19 +190,19 @@ class StableUnCLIPPipelineFastTests(
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class StableUnCLIPPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_unclip(self):
|
||||
expected_image = load_numpy(
|
||||
@@ -236,9 +226,9 @@ class StableUnCLIPPipelineIntegrationTests(unittest.TestCase):
|
||||
assert_mean_pixel_difference(image, expected_image)
|
||||
|
||||
def test_stable_unclip_pipeline_with_sequential_cpu_offloading(self):
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_max_memory_allocated(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -252,6 +242,6 @@ class StableUnCLIPPipelineIntegrationTests(unittest.TestCase):
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
mem_bytes = backend_max_memory_allocated(torch_device)
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 7 GB is allocated
|
||||
assert mem_bytes < 7 * 10**9
|
||||
|
||||
@@ -18,16 +18,12 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
@@ -217,19 +213,19 @@ class StableUnCLIPImg2ImgPipelineFastTests(
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_unclip_l_img2img(self):
|
||||
input_image = load_image(
|
||||
@@ -290,9 +286,9 @@ class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/turtle.png"
|
||||
)
|
||||
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_max_memory_allocated(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
|
||||
"fusing/stable-unclip-2-1-h-img2img", torch_dtype=torch.float16
|
||||
@@ -308,6 +304,6 @@ class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
mem_bytes = backend_max_memory_allocated(torch_device)
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 7 GB is allocated
|
||||
assert mem_bytes < 7 * 10**9
|
||||
|
||||
@@ -87,24 +87,21 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_diffusers_model_is_compatible_variant(self):
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_diffusers_model_is_compatible_variant_mixed(self):
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_diffusers_model_is_not_compatible_variant(self):
|
||||
filenames = [
|
||||
@@ -124,8 +121,7 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
"text_encoder/pytorch_model.fp16.bin",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_transformer_model_is_not_compatible_variant(self):
|
||||
filenames = [
|
||||
@@ -149,8 +145,7 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
|
||||
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16"))
|
||||
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
|
||||
|
||||
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
|
||||
filenames = [
|
||||
@@ -178,8 +173,7 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
"text_encoder/model.fp16-00001-of-00002.safetensors",
|
||||
"text_encoder/model.fp16-00001-of-00002.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_diffusers_is_compatible_sharded(self):
|
||||
filenames = [
|
||||
@@ -195,15 +189,13 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
|
||||
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_diffusers_is_compatible_only_variants(self):
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_diffusers_is_compatible_no_components(self):
|
||||
filenames = [
|
||||
@@ -217,20 +209,6 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_is_compatible_mixed_variants(self):
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
|
||||
def test_is_compatible_variant_and_non_safetensors(self):
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.bin",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))
|
||||
|
||||
|
||||
class VariantCompatibleSiblingsTest(unittest.TestCase):
|
||||
def test_only_non_variants_downloaded(self):
|
||||
|
||||
@@ -588,17 +588,20 @@ class DownloadTests(unittest.TestCase):
|
||||
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
|
||||
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
|
||||
for is_local in [True, False]:
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
local_repo_id = repo_id
|
||||
if is_local:
|
||||
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
|
||||
|
||||
_ = DiffusionPipeline.from_pretrained(
|
||||
local_repo_id,
|
||||
safety_checker=None,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
)
|
||||
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
|
||||
_ = DiffusionPipeline.from_pretrained(
|
||||
local_repo_id,
|
||||
safety_checker=None,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
)
|
||||
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
|
||||
|
||||
def test_download_safetensors_only_variant_exists_for_model(self):
|
||||
variant = None
|
||||
@@ -613,7 +616,7 @@ class DownloadTests(unittest.TestCase):
|
||||
variant=variant,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
|
||||
assert "Error no file name" in str(error_context.exception)
|
||||
|
||||
# text encoder has fp16 variants so we can load it
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -672,7 +675,7 @@ class DownloadTests(unittest.TestCase):
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
|
||||
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
|
||||
assert "Error no file name" in str(error_context.exception)
|
||||
|
||||
def test_download_bin_variant_does_not_exist_for_model(self):
|
||||
variant = "no_ema"
|
||||
@@ -1994,9 +1997,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
reason="Torch Dynamo isn't yet supported for Python 3.12.",
|
||||
)
|
||||
def test_from_save_pretrained_dynamo(self):
|
||||
torch.compiler.rest()
|
||||
with torch._inductor.utils.fresh_inductor_cache():
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=None)
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=None)
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model_path = "google/ddpm-cifar10-32"
|
||||
@@ -2208,7 +2209,7 @@ class TestLoraHotSwappingForPipeline(unittest.TestCase):
|
||||
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
|
||||
# there will be recompilation errors, as torch caches the model when run in the same process.
|
||||
super().tearDown()
|
||||
torch.compiler.reset()
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -2333,21 +2334,21 @@ class TestLoraHotSwappingForPipeline(unittest.TestCase):
|
||||
def test_hotswapping_compiled_pipline_linear(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["conv", "conv1", "conv2"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "conv"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
|
||||
|
||||
@@ -33,6 +33,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
|
||||
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
|
||||
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
|
||||
@@ -1111,14 +1112,14 @@ class PipelineTesterMixin:
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
torch.compiler.reset()
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
torch.compiler.reset()
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -2632,7 +2633,7 @@ class FasterCacheTesterMixin:
|
||||
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
|
||||
pipe = create_pipe()
|
||||
pipe.transformer.enable_cache(self.faster_cache_config)
|
||||
output = run_forward(pipe).flatten().flatten()
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FasterCache disabled
|
||||
@@ -2739,6 +2740,55 @@ class FasterCacheTesterMixin:
|
||||
self.assertTrue(state.cache is None, "Cache should be reset to None.")
|
||||
|
||||
|
||||
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
|
||||
# of the box once there is better cache support/implementation
|
||||
class FirstBlockCacheTesterMixin:
|
||||
# threshold is intentionally set higher than usual values since we're testing with random unconverged models
|
||||
# that will not satisfy the expected properties of the denoiser for caching to be effective
|
||||
first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
|
||||
|
||||
def test_first_block_cache_inference(self, expected_atol: float = 0.1):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
def create_pipe():
|
||||
torch.manual_seed(0)
|
||||
num_layers = 2
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
return pipe
|
||||
|
||||
def run_forward(pipe):
|
||||
torch.manual_seed(0)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
return pipe(**inputs)[0]
|
||||
|
||||
# Run inference without FirstBlockCache
|
||||
pipe = create_pipe()
|
||||
output = run_forward(pipe).flatten()
|
||||
original_image_slice = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FirstBlockCache enabled
|
||||
pipe = create_pipe()
|
||||
pipe.transformer.enable_cache(self.first_block_cache_config)
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FirstBlockCache disabled
|
||||
pipe.transformer.disable_cache()
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
|
||||
"FirstBlockCache outputs should not differ much."
|
||||
)
|
||||
assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
|
||||
"Outputs from normal inference and after disabling cache should not differ."
|
||||
)
|
||||
|
||||
|
||||
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
|
||||
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
|
||||
# reference image.
|
||||
|
||||
@@ -19,44 +19,37 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, TextToVideoZeroPipeline
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
load_pt,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import load_pt, nightly, require_torch_gpu
|
||||
|
||||
from ..test_pipelines_common import assert_mean_pixel_difference
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class TextToVideoZeroPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_full_model(self):
|
||||
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(torch_device)
|
||||
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
|
||||
prompt = "A bear is playing a guitar on Times Square"
|
||||
result = pipe(prompt=prompt, generator=generator).images
|
||||
|
||||
expected_result = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt",
|
||||
weights_only=False,
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt"
|
||||
)
|
||||
|
||||
assert_mean_pixel_difference(result, expected_result)
|
||||
|
||||
@@ -24,11 +24,11 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProject
|
||||
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoZeroSDXLPipeline, UNet2DConditionModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
nightly,
|
||||
require_accelerate_version_greater,
|
||||
require_torch_accelerator,
|
||||
require_accelerator,
|
||||
require_torch_gpu,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -220,7 +220,7 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipe
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_torch_accelerator
|
||||
@require_accelerator
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
components = self.get_dummy_components()
|
||||
for name, module in components.items():
|
||||
@@ -262,7 +262,7 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipe
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
@require_torch_accelerator
|
||||
@require_accelerator
|
||||
@require_accelerate_version_greater("0.17.0")
|
||||
def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
|
||||
components = self.get_dummy_components()
|
||||
@@ -285,7 +285,7 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipe
|
||||
pass
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_torch_accelerator
|
||||
@require_accelerator
|
||||
def test_save_load_float16(self, expected_max_diff=1e-2):
|
||||
components = self.get_dummy_components()
|
||||
for name, module in components.items():
|
||||
@@ -337,7 +337,7 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipe
|
||||
def test_sequential_cpu_offload_forward_pass(self):
|
||||
pass
|
||||
|
||||
@require_torch_accelerator
|
||||
@require_accelerator
|
||||
def test_to_device(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -365,19 +365,19 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipe
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class TextToVideoZeroSDXLPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_full_model(self):
|
||||
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
@@ -23,14 +23,10 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni
|
||||
from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
|
||||
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
enable_full_determinism,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
@@ -430,13 +426,13 @@ class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_unclip_karlo_cpu_fp32(self):
|
||||
expected_image = load_numpy(
|
||||
@@ -462,19 +458,19 @@ class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class UnCLIPPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_unclip_karlo(self):
|
||||
expected_image = load_numpy(
|
||||
@@ -500,9 +496,9 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
|
||||
assert_mean_pixel_difference(image, expected_image)
|
||||
|
||||
def test_unclip_pipeline_with_sequential_cpu_offloading(self):
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_max_memory_allocated(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -518,6 +514,6 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
mem_bytes = backend_max_memory_allocated(torch_device)
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 7 GB is allocated
|
||||
assert mem_bytes < 7 * 10**9
|
||||
|
||||
@@ -37,13 +37,12 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
@@ -497,19 +496,19 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_unclip_image_variation_karlo(self):
|
||||
input_image = load_image(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import gc
|
||||
import random
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -26,7 +27,9 @@ from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
load_image,
|
||||
nightly,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
run_test_in_subprocess,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
@@ -42,6 +45,38 @@ from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, Pipeline
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_unidiffuser_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
inputs = in_queue.get(timeout=timeout)
|
||||
torch_device = inputs.pop("torch_device")
|
||||
seed = inputs.pop("seed")
|
||||
inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
|
||||
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
|
||||
# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
|
||||
assert np.abs(image_slice - expected_slice).max() < 1e-1
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class UniDiffuserPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
):
|
||||
@@ -655,6 +690,19 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
|
||||
expected_text_prefix = "An astronaut"
|
||||
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
|
||||
|
||||
@unittest.skip(reason="Skip torch.compile test to speed up the slow test suite.")
|
||||
@require_torch_2
|
||||
def test_unidiffuser_compile(self, seed=0):
|
||||
inputs = self.get_inputs(torch_device, seed=seed, generate_latents=True)
|
||||
# Delete prompt and image for joint inference.
|
||||
del inputs["prompt"]
|
||||
del inputs["image"]
|
||||
# Can't pickle a Generator object
|
||||
del inputs["generator"]
|
||||
inputs["torch_device"] = torch_device
|
||||
inputs["seed"] = seed
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_unidiffuser_compile, inputs=inputs)
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
|
||||
Reference in New Issue
Block a user