Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e30dc5176c | |||
| ca4a16f94c | |||
| 30a72ee140 | |||
| 58dfcf9e92 | |||
| d03240801f | |||
| e62804ffbd | |||
| bb1d9a8b75 |
@@ -316,6 +316,67 @@ if integrity_checker.test_image(image_):
|
||||
raise ValueError("Your image has been flagged. Choose another prompt/image or try again.")
|
||||
```
|
||||
|
||||
### Kontext Inpainting
|
||||
`FluxKontextInpaintPipeline` enables image modification within a fixed mask region. It currently supports both text-based conditioning and image-reference conditioning.
|
||||
<hfoptions id="kontext-inpaint">
|
||||
<hfoption id="text-only">
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxKontextInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
prompt = "Change the yellow dinosaur to green one"
|
||||
img_url = (
|
||||
"https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
|
||||
)
|
||||
mask_url = (
|
||||
"https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
|
||||
)
|
||||
|
||||
source = load_image(img_url)
|
||||
mask = load_image(mask_url)
|
||||
|
||||
pipe = FluxKontextInpaintPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
|
||||
image.save("kontext_inpainting_normal.png")
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="image conditioning">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxKontextInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = FluxKontextInpaintPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "Replace this ball"
|
||||
img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
|
||||
mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
|
||||
image_reference_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
|
||||
|
||||
source = load_image(img_url)
|
||||
mask = load_image(mask_url)
|
||||
image_reference = load_image(image_reference_url)
|
||||
|
||||
mask = pipe.mask_processor.blur(mask, blur_factor=12)
|
||||
image = pipe(
|
||||
prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
|
||||
).images[0]
|
||||
image.save("kontext_inpainting_ref.png")
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
|
||||
|
||||
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
|
||||
@@ -646,3 +707,15 @@ image.save("flux-fp8-dev.png")
|
||||
[[autodoc]] FluxFillPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## FluxKontextPipeline
|
||||
|
||||
[[autodoc]] FluxKontextPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## FluxKontextInpaintPipeline
|
||||
|
||||
[[autodoc]] FluxKontextInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -51,10 +51,10 @@ t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=comp
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Components are only loaded and registered when using [`~ModularPipeline.load_components`] or [`~ModularPipeline.load_default_components`]. The example below uses [`~ModularPipeline.load_default_components`] to create a second pipeline that reuses all the components from the first one, and assigns it to a different collection
|
||||
Components are only loaded and registered when using [`~ModularPipeline.load_components`] or [`~ModularPipeline.load_components`]. The example below uses [`~ModularPipeline.load_components`] to create a second pipeline that reuses all the components from the first one, and assigns it to a different collection
|
||||
|
||||
```py
|
||||
pipe.load_default_components()
|
||||
pipe.load_components()
|
||||
pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
|
||||
```
|
||||
|
||||
@@ -187,4 +187,4 @@ comp.enable_auto_cpu_offload(device="cuda")
|
||||
|
||||
All models begin on the CPU and [`ComponentsManager`] moves them to the appropriate device right before they're needed, and moves other models back to the CPU when GPU memory is low.
|
||||
|
||||
You can set your own rules for which models to offload first.
|
||||
You can set your own rules for which models to offload first.
|
||||
|
||||
@@ -75,13 +75,13 @@ Guiders that are already saved on the Hub with a `modular_model_index.json` file
|
||||
}
|
||||
```
|
||||
|
||||
The guider is only created after calling [`~ModularPipeline.load_default_components`] based on the loading specification in `modular_model_index.json`.
|
||||
The guider is only created after calling [`~ModularPipeline.load_components`] based on the loading specification in `modular_model_index.json`.
|
||||
|
||||
```py
|
||||
t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider")
|
||||
# not created during init
|
||||
assert t2i_pipeline.guider is None
|
||||
t2i_pipeline.load_default_components()
|
||||
t2i_pipeline.load_components()
|
||||
# loaded as PAG guider
|
||||
t2i_pipeline.guider
|
||||
```
|
||||
@@ -172,4 +172,4 @@ t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
</hfoptions>
|
||||
|
||||
@@ -29,7 +29,7 @@ blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
|
||||
pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
|
||||
@@ -49,7 +49,7 @@ blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
|
||||
pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
|
||||
@@ -73,7 +73,7 @@ blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
|
||||
pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
|
||||
@@ -176,15 +176,15 @@ diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remot
|
||||
|
||||
## Loading components
|
||||
|
||||
A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load all components with [`~ModularPipeline.load_default_components`] or only load specific components with [`~ModularPipeline.load_components`].
|
||||
A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load all components with [`~ModularPipeline.load_components`] or only load specific components with [`~ModularPipeline.load_components`].
|
||||
|
||||
<hfoptions id="load">
|
||||
<hfoption id="load_default_components">
|
||||
<hfoption id="load_components">
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
t2i_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
t2i_pipeline.load_components(torch_dtype=torch.float16)
|
||||
t2i_pipeline.to("cuda")
|
||||
```
|
||||
|
||||
@@ -355,4 +355,4 @@ The [config.json](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/
|
||||
"ModularPipelineBlocks": "block.DiffDiffBlocks"
|
||||
}
|
||||
}
|
||||
```
|
||||
```
|
||||
|
||||
@@ -173,9 +173,9 @@ print(dd_blocks)
|
||||
|
||||
## ModularPipeline
|
||||
|
||||
Convert the [`SequentialPipelineBlocks`] into a [`ModularPipeline`] with the [`ModularPipeline.init_pipeline`] method. This initializes the expected components to load from a `modular_model_index.json` file. Explicitly load the components by calling [`ModularPipeline.load_default_components`].
|
||||
Convert the [`SequentialPipelineBlocks`] into a [`ModularPipeline`] with the [`ModularPipeline.init_pipeline`] method. This initializes the expected components to load from a `modular_model_index.json` file. Explicitly load the components by calling [`ModularPipeline.load_components`].
|
||||
|
||||
It is a good idea to initialize the [`ComponentManager`] with the pipeline to help manage the different components. Once you call [`~ModularPipeline.load_default_components`], the components are registered to the [`ComponentManager`] and can be shared between workflows. The example below uses the `collection` argument to assign the components a `"diffdiff"` label for better organization.
|
||||
It is a good idea to initialize the [`ComponentManager`] with the pipeline to help manage the different components. Once you call [`~ModularPipeline.load_components`], the components are registered to the [`ComponentManager`] and can be shared between workflows. The example below uses the `collection` argument to assign the components a `"diffdiff"` label for better organization.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import ComponentsManager
|
||||
@@ -209,11 +209,11 @@ Use the [`sub_blocks.insert`] method to insert it into the [`ModularPipeline`].
|
||||
dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
|
||||
```
|
||||
|
||||
Call [`~ModularPipeline.init_pipeline`] to initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_default_components`] to load the model components. Load and set the IP-Adapter to run the pipeline.
|
||||
Call [`~ModularPipeline.init_pipeline`] to initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_components`] to load the model components. Load and set the IP-Adapter to run the pipeline.
|
||||
|
||||
```py
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
dd_pipeline.loader.set_ip_adapter_scale(0.6)
|
||||
dd_pipeline = dd_pipeline.to(device)
|
||||
@@ -260,14 +260,14 @@ class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
|
||||
controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
|
||||
```
|
||||
|
||||
Insert the `controlnet_input` block and replace the `denoise` block with the new `controlnet_denoise_block`. Initialize a [`ModularPipeline`] and [`~ModularPipeline.load_default_components`] into it.
|
||||
Insert the `controlnet_input` block and replace the `denoise` block with the new `controlnet_denoise_block`. Initialize a [`ModularPipeline`] and [`~ModularPipeline.load_components`] into it.
|
||||
|
||||
```py
|
||||
dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
|
||||
dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
dd_pipeline = dd_pipeline.to(device)
|
||||
|
||||
control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
|
||||
@@ -320,7 +320,7 @@ Call [`SequentialPipelineBlocks.from_blocks_dict`] to create a [`SequentialPipel
|
||||
```py
|
||||
dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
|
||||
dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## Share
|
||||
@@ -340,5 +340,5 @@ from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
|
||||
components = ComponentsManager()
|
||||
|
||||
diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
|
||||
diffdiff_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
```
|
||||
diffdiff_pipeline.load_components(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
@@ -162,6 +162,9 @@ Take a look at the [Quantization](./quantization/overview) section for more deta
|
||||
|
||||
## Optimizations
|
||||
|
||||
> [!TIP]
|
||||
> Optimization is dependent on hardware specs such as memory. Use this [Space](https://huggingface.co/spaces/diffusers/optimized-diffusers-code) to generate code examples that include all of Diffusers' available memory and speed optimization techniques for any model you're using.
|
||||
|
||||
Modern diffusion models are very large and have billions of parameters. The iterative denoising process is also computationally intensive and slow. Diffusers provides techniques for reducing memory usage and boosting inference speed. These techniques can be combined with quantization to optimize for both memory usage and inference speed.
|
||||
|
||||
### Memory usage
|
||||
|
||||
@@ -48,10 +48,10 @@ t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=comp
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
组件仅在调用 [`~ModularPipeline.load_components`] 或 [`~ModularPipeline.load_default_components`] 时加载和注册。以下示例使用 [`~ModularPipeline.load_default_components`] 创建第二个管道,重用第一个管道的所有组件,并将其分配到不同的集合。
|
||||
组件仅在调用 [`~ModularPipeline.load_components`] 或 [`~ModularPipeline.load_components`] 时加载和注册。以下示例使用 [`~ModularPipeline.load_components`] 创建第二个管道,重用第一个管道的所有组件,并将其分配到不同的集合。
|
||||
|
||||
```py
|
||||
pipe.load_default_components()
|
||||
pipe.load_components()
|
||||
pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
|
||||
```
|
||||
|
||||
@@ -185,4 +185,4 @@ comp.enable_auto_cpu_offload(device="cuda")
|
||||
|
||||
所有模型开始时都在 CPU 上,[`ComponentsManager`] 在需要它们之前将它们移动到适当的设备,并在 GPU 内存不足时将其他模型移回 CPU。
|
||||
|
||||
您可以设置自己的规则来决定哪些模型要卸载。
|
||||
您可以设置自己的规则来决定哪些模型要卸载。
|
||||
|
||||
@@ -73,13 +73,13 @@ ComponentSpec(name='guider', type_hint=<class 'diffusers.guiders.perturbed_atten
|
||||
}
|
||||
```
|
||||
|
||||
引导器只有在调用 [`~ModularPipeline.load_default_components`] 之后才会创建,基于 `modular_model_index.json` 中的加载规范。
|
||||
引导器只有在调用 [`~ModularPipeline.load_components`] 之后才会创建,基于 `modular_model_index.json` 中的加载规范。
|
||||
|
||||
```py
|
||||
t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider")
|
||||
# 在初始化时未创建
|
||||
assert t2i_pipeline.guider is None
|
||||
t2i_pipeline.load_default_components()
|
||||
t2i_pipeline.load_components()
|
||||
# 加载为 PAG 引导器
|
||||
t2i_pipeline.guider
|
||||
```
|
||||
@@ -170,4 +170,4 @@ t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
</hfoptions>
|
||||
|
||||
@@ -28,7 +28,7 @@ blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
|
||||
pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
|
||||
@@ -48,7 +48,7 @@ blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
|
||||
pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
|
||||
@@ -72,7 +72,7 @@ blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
|
||||
pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
|
||||
@@ -176,15 +176,15 @@ diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remot
|
||||
|
||||
## 加载组件
|
||||
|
||||
一个[`ModularPipeline`]不会自动实例化组件。它只加载配置和组件规范。您可以使用[`~ModularPipeline.load_default_components`]加载所有组件,或仅使用[`~ModularPipeline.load_components`]加载特定组件。
|
||||
一个[`ModularPipeline`]不会自动实例化组件。它只加载配置和组件规范。您可以使用[`~ModularPipeline.load_components`]加载所有组件,或仅使用[`~ModularPipeline.load_components`]加载特定组件。
|
||||
|
||||
<hfoptions id="load">
|
||||
<hfoption id="load_default_components">
|
||||
<hfoption id="load_components">
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
t2i_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
t2i_pipeline.load_components(torch_dtype=torch.float16)
|
||||
t2i_pipeline.to("cuda")
|
||||
```
|
||||
|
||||
|
||||
@@ -175,7 +175,7 @@ print(dd_blocks)
|
||||
将 [`SequentialPipelineBlocks`] 转换为 [`ModularPipeline`],使用 [`ModularPipeline.init_pipeline`] 方法。这会初始化从 `modular_model_index.json` 文件加载的预期组件。通过调用 [`ModularPipeline.load_defau
|
||||
lt_components`]。
|
||||
|
||||
初始化[`ComponentManager`]时传入pipeline是一个好主意,以帮助管理不同的组件。一旦调用[`~ModularPipeline.load_default_components`],组件就会被注册到[`ComponentManager`]中,并且可以在工作流之间共享。下面的例子使用`collection`参数为组件分配了一个`"diffdiff"`标签,以便更好地组织。
|
||||
初始化[`ComponentManager`]时传入pipeline是一个好主意,以帮助管理不同的组件。一旦调用[`~ModularPipeline.load_components`],组件就会被注册到[`ComponentManager`]中,并且可以在工作流之间共享。下面的例子使用`collection`参数为组件分配了一个`"diffdiff"`标签,以便更好地组织。
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import ComponentsManager
|
||||
@@ -209,11 +209,11 @@ ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
|
||||
dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
|
||||
```
|
||||
|
||||
调用[`~ModularPipeline.init_pipeline`]来初始化一个[`ModularPipeline`],并使用[`~ModularPipeline.load_default_components`]加载模型组件。加载并设置IP-Adapter以运行pipeline。
|
||||
调用[`~ModularPipeline.init_pipeline`]来初始化一个[`ModularPipeline`],并使用[`~ModularPipeline.load_components`]加载模型组件。加载并设置IP-Adapter以运行pipeline。
|
||||
|
||||
```py
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
dd_pipeline.loader.set_ip_adapter_scale(0.6)
|
||||
dd_pipeline = dd_pipeline.to(device)
|
||||
@@ -261,14 +261,14 @@ class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
|
||||
controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
|
||||
```
|
||||
|
||||
插入 `controlnet_input` 块并用新的 `controlnet_denoise_block` 替换 `denoise` 块。初始化一个 [`ModularPipeline`] 并将 [`~ModularPipeline.load_default_components`] 加载到其中。
|
||||
插入 `controlnet_input` 块并用新的 `controlnet_denoise_block` 替换 `denoise` 块。初始化一个 [`ModularPipeline`] 并将 [`~ModularPipeline.load_components`] 加载到其中。
|
||||
|
||||
```py
|
||||
dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
|
||||
dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
dd_pipeline = dd_pipeline.to(device)
|
||||
|
||||
control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
|
||||
@@ -322,7 +322,7 @@ DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoIn
|
||||
```py
|
||||
dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
|
||||
dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
dd_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_components(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## 分享
|
||||
@@ -342,5 +342,5 @@ from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
|
||||
components = ComponentsManager()
|
||||
|
||||
diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
|
||||
diffdiff_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
diffdiff_pipeline.load_components(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
@@ -13,7 +13,6 @@ from .utils import (
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_nunchaku_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_optimum_quanto_available,
|
||||
@@ -100,18 +99,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_nunchaku_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_nunchaku_objects
|
||||
|
||||
_import_structure["utils.dummy_nunchaku_objects"] = [
|
||||
name for name in dir(dummy_nunchaku_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["quantizers.quantization_config"].append("NunchakuConfig")
|
||||
|
||||
try:
|
||||
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -804,14 +791,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .quantizers.quantization_config import QuantoConfig
|
||||
|
||||
try:
|
||||
if not is_nunchaku_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_optimum_quanto_objects import *
|
||||
else:
|
||||
from .quantizers.quantization_config import NunchakuConfig
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -23,7 +23,6 @@ from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..quantizers.quantization_config import NunchakuConfig
|
||||
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from .single_file_utils import (
|
||||
@@ -43,7 +42,6 @@ from .single_file_utils import (
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_lumina2_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_nunchaku_flux_to_diffusers,
|
||||
convert_sana_transformer_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
@@ -192,23 +190,6 @@ def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
||||
return mapping_kwargs
|
||||
|
||||
|
||||
def _maybe_determine_modules_to_not_convert(quantization_config, state_dict):
|
||||
if quantization_config is None:
|
||||
return None
|
||||
else:
|
||||
is_nunchaku = quantization_config.quant_method == "nunchaku"
|
||||
if not is_nunchaku:
|
||||
return None
|
||||
else:
|
||||
no_qweight = set()
|
||||
for key in state_dict:
|
||||
if key.endswith(".weight"):
|
||||
# module name is everything except the last piece after "."
|
||||
module_name = ".".join(key.split(".")[:-1])
|
||||
no_qweight.add(module_name)
|
||||
return sorted(no_qweight)
|
||||
|
||||
|
||||
class FromOriginalModelMixin:
|
||||
"""
|
||||
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
|
||||
@@ -423,14 +404,8 @@ class FromOriginalModelMixin:
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
model_state_dict = model.state_dict()
|
||||
# TODO: Only flux nunchaku checkpoint for now. Unify with how checkpoint mappers are done.
|
||||
# For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`.
|
||||
if quantization_config is not None and quantization_config.quant_method == "nunchaku":
|
||||
diffusers_format_checkpoint = convert_nunchaku_flux_to_diffusers(
|
||||
checkpoint, model_state_dict=model_state_dict
|
||||
)
|
||||
elif _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
|
||||
|
||||
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
@@ -441,27 +416,6 @@ class FromOriginalModelMixin:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
# This step is better off here than above because `diffusers_format_checkpoint` holds the keys we expect.
|
||||
# We can move it to a separate function as well.
|
||||
if quantization_config is not None:
|
||||
original_modules_to_not_convert = quantization_config.modules_to_not_convert or []
|
||||
determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert(
|
||||
quantization_config, checkpoint
|
||||
)
|
||||
if determined_modules_to_not_convert:
|
||||
determined_modules_to_not_convert.extend(original_modules_to_not_convert)
|
||||
determined_modules_to_not_convert = list(set(determined_modules_to_not_convert))
|
||||
logger.debug(
|
||||
f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {determined_modules_to_not_convert}."
|
||||
)
|
||||
modified_quant_config = quantization_config.to_dict()
|
||||
modified_quant_config["modules_to_not_convert"] = determined_modules_to_not_convert
|
||||
# TODO: figure out a better way.
|
||||
modified_quant_config = NunchakuConfig.from_dict(modified_quant_config)
|
||||
setattr(hf_quantizer, "quantization_config", modified_quant_config)
|
||||
logger.debug("TODO")
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
@@ -489,12 +443,6 @@ class FromOriginalModelMixin:
|
||||
unexpected_keys = [
|
||||
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
|
||||
]
|
||||
for k in unexpected_keys:
|
||||
if "single_transformer_blocks.0" in k:
|
||||
print(f"Unexpected {k=}")
|
||||
for k in empty_state_dict:
|
||||
if "single_transformer_blocks.0" in k:
|
||||
print(f"model {k=}")
|
||||
device_map = {"": param_device}
|
||||
load_model_dict_into_meta(
|
||||
model,
|
||||
|
||||
@@ -2189,105 +2189,6 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
# Adapted from https://github.com/nunchaku-tech/nunchaku/blob/3ec299f439f9986a69ded320798cab4e258c871d/nunchaku/models/transformers/transformer_flux_v2.py#L395
|
||||
def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
|
||||
from .single_file_utils_nunchaku import _unpack_qkv_state_dict
|
||||
|
||||
_SMOOTH_ORIG_RE = re.compile(r"\.smooth_orig(\.|$)")
|
||||
_SMOOTH_RE = re.compile(r"\.smooth(\.|$)")
|
||||
|
||||
new_state_dict = {}
|
||||
model_state_dict = kwargs["model_state_dict"]
|
||||
|
||||
ckpt_keys = list(checkpoint.keys())
|
||||
for k in ckpt_keys:
|
||||
if "qweight" in k:
|
||||
# only the shape information of this tensor is needed
|
||||
v = checkpoint[k]
|
||||
# if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
|
||||
for t in ["lora_up", "lora_down"]:
|
||||
new_k = k.replace(".qweight", f".{t}")
|
||||
if new_k not in ckpt_keys:
|
||||
oc, ic = v.shape
|
||||
ic = ic * 2 # v is packed into INT8, so we need to double the size
|
||||
checkpoint[k.replace(".qweight", f".{t}")] = torch.zeros(
|
||||
(0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
for k, v in checkpoint.items():
|
||||
new_k = k # start with original, then apply independent replacements
|
||||
|
||||
if k.startswith("single_transformer_blocks."):
|
||||
# attention / qkv / norms
|
||||
new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
|
||||
new_k = new_k.replace(".out_proj.", ".proj_out.")
|
||||
new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
|
||||
new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
|
||||
|
||||
# mlp heads
|
||||
new_k = new_k.replace(".mlp_fc1.", ".proj_mlp.")
|
||||
new_k = new_k.replace(".mlp_fc2.", ".proj_out.")
|
||||
|
||||
# smooth params (use regex to avoid substring collisions)
|
||||
new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
|
||||
new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
|
||||
|
||||
# lora -> proj
|
||||
new_k = new_k.replace(".lora_down", ".proj_down")
|
||||
new_k = new_k.replace(".lora_up", ".proj_up")
|
||||
|
||||
elif k.startswith("transformer_blocks."):
|
||||
# feed-forward (context & base)
|
||||
new_k = new_k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.")
|
||||
new_k = new_k.replace(".mlp_context_fc2.", ".ff_context.net.2.")
|
||||
new_k = new_k.replace(".mlp_fc1.", ".ff.net.0.proj.")
|
||||
new_k = new_k.replace(".mlp_fc2.", ".ff.net.2.")
|
||||
|
||||
# attention projections
|
||||
new_k = new_k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.")
|
||||
new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
|
||||
new_k = new_k.replace(".out_proj.", ".attn.to_out.0.")
|
||||
new_k = new_k.replace(".out_proj_context.", ".attn.to_add_out.")
|
||||
|
||||
# norms
|
||||
new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
|
||||
new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
|
||||
new_k = new_k.replace(".norm_added_k.", ".attn.norm_added_k.")
|
||||
new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.")
|
||||
|
||||
# smooth params
|
||||
new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
|
||||
new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
|
||||
|
||||
# lora -> proj
|
||||
new_k = new_k.replace(".lora_down", ".proj_down")
|
||||
new_k = new_k.replace(".lora_up", ".proj_up")
|
||||
|
||||
new_state_dict[new_k] = v
|
||||
|
||||
new_state_dict = _unpack_qkv_state_dict(new_state_dict)
|
||||
|
||||
# some remnant keys need to be patched
|
||||
new_sd_keys = list(new_state_dict.keys())
|
||||
for k in new_sd_keys:
|
||||
if "qweight" in k:
|
||||
no_qweight_k = ".".join(k.split(".qweight")[:-1])
|
||||
for unexpected_k in ["wzeros"]:
|
||||
unexpected_k = no_qweight_k + f".{unexpected_k}"
|
||||
if unexpected_k in new_sd_keys:
|
||||
_ = new_state_dict.pop(unexpected_k)
|
||||
for k in model_state_dict:
|
||||
if k not in new_state_dict:
|
||||
# CPU device for now
|
||||
new_state_dict[k] = torch.ones_like(model_state_dict[k], device="cpu")
|
||||
|
||||
for k in new_state_dict:
|
||||
if "single_transformer_blocks.0" in k and k.endswith(".weight"):
|
||||
print(f"{k=}")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
_QKV_ANCHORS_NUNCHAKU = ("to_qkv", "add_qkv_proj")
|
||||
_ALLOWED_SUFFIXES_NUNCHAKU = {
|
||||
"bias",
|
||||
"proj_down",
|
||||
"proj_up",
|
||||
"qweight",
|
||||
"smooth_factor",
|
||||
"smooth_factor_orig",
|
||||
"wscales",
|
||||
}
|
||||
|
||||
_QKV_NUNCHAKU_REGEX = re.compile(
|
||||
rf"^(?P<prefix>.*)\.(?:{'|'.join(map(re.escape, _QKV_ANCHORS_NUNCHAKU))})\.(?P<suffix>.+)$"
|
||||
)
|
||||
|
||||
|
||||
def _pick_split_dim(t: torch.Tensor, suffix: str) -> int:
|
||||
"""
|
||||
Choose which dimension to split by 3. Heuristics:
|
||||
- 1D -> dim 0
|
||||
- 2D -> prefer dim=1 for 'qweight' (common layout [*, 3*out_features]),
|
||||
otherwise prefer dim=0 (common layout [3*out_features, *]).
|
||||
- If preferred dim isn't divisible by 3, try the other; else error.
|
||||
"""
|
||||
shape = list(t.shape)
|
||||
if len(shape) == 0:
|
||||
raise ValueError("Cannot split a scalar into Q/K/V.")
|
||||
|
||||
if len(shape) == 1:
|
||||
dim = 0
|
||||
if shape[dim] % 3 == 0:
|
||||
return dim
|
||||
raise ValueError(f"1D tensor of length {shape[0]} not divisible by 3.")
|
||||
|
||||
# len(shape) >= 2
|
||||
preferred = 1 if suffix == "qweight" else 0
|
||||
other = 0 if preferred == 1 else 1
|
||||
|
||||
if shape[preferred] % 3 == 0:
|
||||
return preferred
|
||||
if shape[other] % 3 == 0:
|
||||
return other
|
||||
|
||||
# Fall back: any dim divisible by 3
|
||||
for d, s in enumerate(shape):
|
||||
if s % 3 == 0:
|
||||
return d
|
||||
|
||||
raise ValueError(f"None of the dims {shape} are divisible by 3 for suffix '{suffix}'.")
|
||||
|
||||
|
||||
def _split_qkv(t: torch.Tensor, dim: int):
|
||||
return torch.tensor_split(t, 3, dim=dim)
|
||||
|
||||
|
||||
def _unpack_qkv_state_dict(
|
||||
state_dict: dict, anchors=_QKV_ANCHORS_NUNCHAKU, allowed_suffixes=_ALLOWED_SUFFIXES_NUNCHAKU
|
||||
):
|
||||
"""
|
||||
Convert fused QKV entries (e.g., '...to_qkv.bias', '...qkv_proj.wscales') into separate Q/K/V entries:
|
||||
'...to_q.bias', '...to_k.bias', '...to_v.bias' '...to_q.wscales', '...to_k.wscales', '...to_v.wscales'
|
||||
Returns a NEW dict; original is not modified.
|
||||
|
||||
Only keys with suffix in `allowed_suffixes` are processed. Keys with non-divisible-by-3 tensors raise a ValueError.:
|
||||
"""
|
||||
anchors = tuple(anchors)
|
||||
allowed_suffixes = set(allowed_suffixes)
|
||||
|
||||
new_sd: dict = {}
|
||||
sd_keys = list(state_dict.keys())
|
||||
for k in sd_keys:
|
||||
m = _QKV_NUNCHAKU_REGEX.match(k)
|
||||
v = state_dict.pop(k)
|
||||
if m:
|
||||
suffix = m.group("suffix")
|
||||
if suffix not in allowed_suffixes:
|
||||
# keep as-is if it's not one of the targeted suffixes
|
||||
new_sd[k] = v
|
||||
continue
|
||||
|
||||
prefix = m.group("prefix") # everything before .to_qkv/.qkv_proj
|
||||
# Decide split axis
|
||||
split_dim = _pick_split_dim(v, suffix)
|
||||
q, k_, vv = _split_qkv(v, dim=split_dim)
|
||||
|
||||
# Build new keys
|
||||
base_q = f"{prefix}.to_q.{suffix}"
|
||||
base_k = f"{prefix}.to_k.{suffix}"
|
||||
base_v = f"{prefix}.to_v.{suffix}"
|
||||
|
||||
# Write into result dict
|
||||
new_sd[base_q] = q
|
||||
new_sd[base_k] = k_
|
||||
new_sd[base_v] = vv
|
||||
else:
|
||||
# not a fused qkv key
|
||||
new_sd[k] = v
|
||||
|
||||
return new_sd
|
||||
@@ -297,13 +297,6 @@ def load_model_dict_into_meta(
|
||||
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
|
||||
elif param_device == "cpu" and state_dict_index is not None:
|
||||
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
|
||||
# This check below might be a bit counter-intuitive in nature. This is because we're checking if the param
|
||||
# or its module is quantized and if so, we're proceeding with creating a quantized param. This is because
|
||||
# of the way pre-trained models are loaded. They're initialized under "meta" device, where
|
||||
# quantization layers are first injected. Hence, for a model that is either pre-quantized or supplemented
|
||||
# with a `quantization_config` during `from_pretrained`, we expect `check_if_quantized_param` to return True.
|
||||
# Then depending on the quantization backend being used, we run the actual quantization step under
|
||||
# `create_quantized_param`.
|
||||
elif is_quantized and (
|
||||
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
|
||||
):
|
||||
|
||||
@@ -1409,7 +1409,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# YiYi TODO:
|
||||
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
|
||||
# 2. do we need ConfigSpec? the are basically just key/val kwargs
|
||||
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_default_components(), load_components()
|
||||
# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_components()
|
||||
class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
Base class for all Modular pipelines.
|
||||
@@ -1478,7 +1478,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- Components with default_creation_method="from_config" are created immediately, its specs are not included
|
||||
in config dict and will not be saved in `modular_model_index.json`
|
||||
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
|
||||
`load_default_components()`/`load_components()`
|
||||
`load_components()` (with or without specific component names)
|
||||
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and
|
||||
config values, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
|
||||
@@ -1541,20 +1541,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def load_default_components(self, **kwargs):
|
||||
"""
|
||||
Load from_pretrained components using the loading specs in the config dict.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc.
|
||||
"""
|
||||
names = [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
]
|
||||
self.load_components(names=names, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
@@ -1682,8 +1668,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- non from_pretrained components are created during __init__ and registered as the object itself
|
||||
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
|
||||
loader.update_components(guider=guider_spec)
|
||||
- (from_pretrained) Components are loaded with the `load_default_components()` method: e.g.
|
||||
loader.load_default_components(names=["unet"])
|
||||
- (from_pretrained) Components are loaded with the `load_components()` method: e.g.
|
||||
loader.load_components(names=["unet"]) or loader.load_components() to load all default components
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments where keys are component names and values are component objects.
|
||||
@@ -1995,13 +1981,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
self.register_to_config(**config_to_register)
|
||||
|
||||
# YiYi TODO: support map for additional from_pretrained kwargs
|
||||
# YiYi/Dhruv TODO: consolidate load_components and load_default_components?
|
||||
def load_components(self, names: Union[List[str], str], **kwargs):
|
||||
def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
|
||||
"""
|
||||
Load selected components from specs.
|
||||
|
||||
Args:
|
||||
names: List of component names to load; by default will not load any components
|
||||
names: List of component names to load. If None, will load all components with
|
||||
default_creation_method == "from_pretrained". If provided as a list or string, will load only the
|
||||
specified components.
|
||||
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
|
||||
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
|
||||
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
|
||||
@@ -2009,7 +1996,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
`variant`, `revision`, etc.
|
||||
"""
|
||||
|
||||
if isinstance(names, str):
|
||||
if names is None:
|
||||
names = [
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
]
|
||||
elif isinstance(names, str):
|
||||
names = [names]
|
||||
elif not isinstance(names, list):
|
||||
raise ValueError(f"Invalid type for names: {type(names)}")
|
||||
|
||||
@@ -21,11 +21,9 @@ from typing import Dict, Optional, Union
|
||||
|
||||
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
|
||||
from .gguf import GGUFQuantizer
|
||||
from .nunchaku import NunchakuQuantizer
|
||||
from .quantization_config import (
|
||||
BitsAndBytesConfig,
|
||||
GGUFQuantizationConfig,
|
||||
NunchakuConfig,
|
||||
QuantizationConfigMixin,
|
||||
QuantizationMethod,
|
||||
QuantoConfig,
|
||||
@@ -41,7 +39,6 @@ AUTO_QUANTIZER_MAPPING = {
|
||||
"gguf": GGUFQuantizer,
|
||||
"quanto": QuantoQuantizer,
|
||||
"torchao": TorchAoHfQuantizer,
|
||||
"nunchaku": NunchakuQuantizer,
|
||||
}
|
||||
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
@@ -50,13 +47,12 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"gguf": GGUFQuantizationConfig,
|
||||
"quanto": QuantoConfig,
|
||||
"torchao": TorchAoConfig,
|
||||
"nunchaku": NunchakuConfig,
|
||||
}
|
||||
|
||||
|
||||
class DiffusersAutoQuantizer:
|
||||
"""
|
||||
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
|
||||
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
|
||||
`DiffusersQuantizer` given the `QuantizationConfig`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class GGUFQuantizer(DiffusersQuantizer):
|
||||
def check_if_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: Union["torch.Tensor"],
|
||||
param_value: Union["GGUFParameter", "torch.Tensor"],
|
||||
param_name: str,
|
||||
state_dict: Dict[str, Any],
|
||||
**kwargs,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .nunchaku_quantizer import NunchakuQuantizer
|
||||
@@ -1,174 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from diffusers.utils.import_utils import is_nunchaku_version
|
||||
|
||||
from ...utils import get_module_from_name, is_accelerate_available, is_nunchaku_available, is_torch_available, logging
|
||||
from ...utils.torch_utils import is_fp8_available
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_nunchaku_available():
|
||||
from .utils import replace_with_nunchaku_linear
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
KEY_MAP = {
|
||||
"lora_down": "proj_down",
|
||||
"lora_up": "proj_up",
|
||||
"smooth_orig": "smooth_factor_orig",
|
||||
"smooth": "smooth_factor",
|
||||
}
|
||||
|
||||
|
||||
class NunchakuQuantizer(DiffusersQuantizer):
|
||||
r"""
|
||||
Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku)
|
||||
"""
|
||||
|
||||
use_keep_in_fp32_modules = True
|
||||
requires_calibration = False
|
||||
required_packages = ["nunchaku", "accelerate"]
|
||||
|
||||
def __init__(self, quantization_config, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("No GPU found. A GPU is needed for nunchaku quantization.")
|
||||
|
||||
if not is_nunchaku_available():
|
||||
raise ImportError(
|
||||
"Loading an nunchaku quantized model requires nunchaku library (follow https://nunchaku.tech/docs/nunchaku/installation/installation.html)"
|
||||
)
|
||||
if not is_nunchaku_version(">=", "0.3.1"):
|
||||
raise ImportError(
|
||||
"Loading an nunchaku quantized model requires `nunchaku>=1.0.0`. "
|
||||
"Please upgrade your installation by following https://nunchaku.tech/docs/nunchaku/installation/installation.html."
|
||||
)
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ImportError(
|
||||
"Loading an nunchaku quantized model requires accelerate library (`pip install accelerate`)"
|
||||
)
|
||||
|
||||
# TODO: check
|
||||
# device_map = kwargs.get("device_map", None)
|
||||
# if isinstance(device_map, dict) and len(device_map.keys()) > 1:
|
||||
# raise ValueError(
|
||||
# "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the nunchaku backend"
|
||||
# )
|
||||
|
||||
def check_if_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: Dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
from nunchaku.models.linear import SVDQW4A4Linear
|
||||
|
||||
module, _ = get_module_from_name(model, param_name)
|
||||
if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Create a quantized parameter.
|
||||
"""
|
||||
from nunchaku.models.linear import SVDQW4A4Linear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if tensor_name not in module._parameters and tensor_name not in module._buffers:
|
||||
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
|
||||
|
||||
if isinstance(module, SVDQW4A4Linear):
|
||||
module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(target_device)
|
||||
|
||||
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
|
||||
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
|
||||
return max_memory
|
||||
|
||||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||
precision = self.quantization_config.precision
|
||||
expected_target_dtypes = [torch.int8]
|
||||
if is_fp8_available():
|
||||
expected_target_dtypes.append(torch.float8_e4m3fn)
|
||||
if target_dtype not in expected_target_dtypes:
|
||||
new_target_dtype = self.dtype_map[precision]
|
||||
|
||||
logger.info(f"target_dtype {target_dtype} is replaced by {new_target_dtype} for `nunchaku` quantization")
|
||||
return new_target_dtype
|
||||
else:
|
||||
raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")
|
||||
|
||||
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
|
||||
if torch_dtype is None:
|
||||
# We force the `dtype` to be bfloat16, this is a requirement from `nunchaku`
|
||||
logger.info(
|
||||
"Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
|
||||
"requirements of `nunchaku` to enable model loading in 4-bit. "
|
||||
"Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
|
||||
" torch_dtype=torch.bfloat16 to remove this warning.",
|
||||
torch_dtype,
|
||||
)
|
||||
torch_dtype = torch.bfloat16
|
||||
return torch_dtype
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
device_map,
|
||||
keep_in_fp32_modules: List[str] = [],
|
||||
**kwargs,
|
||||
):
|
||||
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
|
||||
if not isinstance(self.modules_to_not_convert, list):
|
||||
self.modules_to_not_convert = [self.modules_to_not_convert]
|
||||
self.modules_to_not_convert.extend(keep_in_fp32_modules)
|
||||
# Purge `None`.
|
||||
# Unlike `transformers`, we don't know if we should always keep certain modules in FP32
|
||||
# in case of diffusion transformer models. For language models and others alike, `lm_head`
|
||||
# and tied modules are usually kept in FP32.
|
||||
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
|
||||
|
||||
# Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
|
||||
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
|
||||
keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
|
||||
self.modules_to_not_convert.extend(keys_on_cpu)
|
||||
|
||||
model = replace_with_nunchaku_linear(
|
||||
model,
|
||||
modules_to_not_convert=self.modules_to_not_convert,
|
||||
quantization_config=self.quantization_config,
|
||||
)
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
def _process_model_after_weight_loading(self, model, **kwargs):
|
||||
return model
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_trainable(self):
|
||||
return False
|
||||
@@ -1,80 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from ...utils import is_accelerate_available, is_nunchaku_available, logging
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _replace_with_nunchaku_linear(
|
||||
model,
|
||||
svdq_linear_cls,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
):
|
||||
for name, module in model.named_children():
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
current_key_name.append(name)
|
||||
|
||||
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
|
||||
# Check if the current key is not in the `modules_to_not_convert`
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
if not any(
|
||||
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
|
||||
):
|
||||
with init_empty_weights():
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
|
||||
model._modules[name] = svdq_linear_cls(
|
||||
in_features,
|
||||
out_features,
|
||||
rank=quantization_config.rank,
|
||||
bias=module.bias is not None,
|
||||
torch_dtype=module.weight.dtype,
|
||||
)
|
||||
has_been_replaced = True
|
||||
# Store the module class in case we need to transpose the weight later
|
||||
model._modules[name].source_cls = type(module)
|
||||
# Force requires grad to False to avoid unexpected errors
|
||||
model._modules[name].requires_grad_(False)
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_nunchaku_linear(
|
||||
module,
|
||||
svdq_linear_cls,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
# Remove the last key for recursion
|
||||
current_key_name.pop(-1)
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
|
||||
if is_nunchaku_available():
|
||||
from nunchaku.models.linear import SVDQW4A4Linear
|
||||
|
||||
model, _ = _replace_with_nunchaku_linear(
|
||||
model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
|
||||
)
|
||||
|
||||
has_been_replaced = any(
|
||||
isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
|
||||
)
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"You are loading your model in the SVDQuant method but no linear modules were found in your model."
|
||||
" Please double check your model architecture, or submit an issue on github if you think this is"
|
||||
" a bug."
|
||||
)
|
||||
|
||||
return model
|
||||
@@ -46,7 +46,6 @@ class QuantizationMethod(str, Enum):
|
||||
GGUF = "gguf"
|
||||
TORCHAO = "torchao"
|
||||
QUANTO = "quanto"
|
||||
NUNCHAKU = "nunchaku"
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
@@ -725,72 +724,3 @@ class QuantoConfig(QuantizationConfigMixin):
|
||||
accepted_weights = ["float8", "int8", "int4", "int2"]
|
||||
if self.weights_dtype not in accepted_weights:
|
||||
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
|
||||
|
||||
|
||||
class NunchakuConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
This is a wrapper class about all possible attributes and features that you can play with a model that has been
|
||||
loaded using `nunchaku`.
|
||||
|
||||
Args:
|
||||
TODO
|
||||
modules_to_not_convert (`list`, *optional*, default to `None`):
|
||||
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
|
||||
modules left in their original precision (e.g. `norm` layers in Qwen-Image).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str = "svdquant",
|
||||
weight_dtype: str = "int4",
|
||||
weight_scale_dtype: str = None,
|
||||
weight_group_size: int = 64,
|
||||
activation_dtype: str = "int4",
|
||||
activation_scale_dtype: str = None,
|
||||
activation_group_size: int = 64,
|
||||
rank: int = 32,
|
||||
modules_to_not_convert: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.quant_method = QuantizationMethod.NUNCHAKU
|
||||
self.method = method
|
||||
self.weight_dtype = weight_dtype
|
||||
self.weight_scale_dtype = weight_scale_dtype
|
||||
self.weight_group_size = weight_group_size
|
||||
self.activation_dtype = activation_dtype
|
||||
self.activation_scale_dtype = activation_scale_dtype
|
||||
self.activation_group_size = activation_group_size
|
||||
self.rank = rank
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
r"""
|
||||
Safety checker that arguments are correct. Hardware checks were largely adapted from the official `nunchaku`
|
||||
codebase.
|
||||
"""
|
||||
from ..utils.torch_utils import get_device
|
||||
|
||||
device = get_device()
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
|
||||
sm = f"{capability[0]}{capability[1]}"
|
||||
if sm == "120": # you can only use the fp4 models
|
||||
if self.weight_dtype != "fp4_e2m1_all":
|
||||
raise ValueError('Please use "fp4" quantization for Blackwell GPUs.')
|
||||
elif sm in ["75", "80", "86", "89"]:
|
||||
if self.weight_dtype != "int4":
|
||||
raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs.')
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. "
|
||||
"Please use a Turing, Ampere, Ada or Blackwell GPU for this quantization configuration."
|
||||
)
|
||||
|
||||
# TODO: should there be a check for rank?
|
||||
|
||||
def __repr__(self):
|
||||
config_dict = self.to_dict()
|
||||
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
||||
|
||||
@@ -89,7 +89,6 @@ from .import_utils import (
|
||||
is_matplotlib_available,
|
||||
is_nltk_available,
|
||||
is_note_seq_available,
|
||||
is_nunchaku_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_optimum_quanto_available,
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class NunchakuConfig(metaclass=DummyObject):
|
||||
_backends = ["nunchaku"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["nunchaku"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["nunchaku"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["nunchaku"])
|
||||
@@ -217,7 +217,6 @@ _gguf_available, _gguf_version = _is_package_available("gguf")
|
||||
_torchao_available, _torchao_version = _is_package_available("torchao")
|
||||
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
|
||||
_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
|
||||
_nunchaku_available, _nunchaku_version = _is_package_available("nunchaku", get_dist_name=True)
|
||||
_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
|
||||
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
|
||||
_nltk_available, _nltk_version = _is_package_available("nltk")
|
||||
@@ -364,10 +363,6 @@ def is_optimum_quanto_available():
|
||||
return _optimum_quanto_available
|
||||
|
||||
|
||||
def is_nunchaku_available():
|
||||
return _nunchaku_available
|
||||
|
||||
|
||||
def is_timm_available():
|
||||
return _timm_available
|
||||
|
||||
@@ -821,7 +816,7 @@ def is_k_diffusion_version(operation: str, version: str):
|
||||
|
||||
def is_optimum_quanto_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current quanto version to a given reference with an operation.
|
||||
Compares the current Accelerate version to a given reference with an operation.
|
||||
|
||||
Args:
|
||||
operation (`str`):
|
||||
@@ -834,21 +829,6 @@ def is_optimum_quanto_version(operation: str, version: str):
|
||||
return compare_versions(parse(_optimum_quanto_version), operation, version)
|
||||
|
||||
|
||||
def is_nunchaku_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current nunchaku version to a given reference with an operation.
|
||||
|
||||
Args:
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _nunchaku_available:
|
||||
return False
|
||||
return compare_versions(parse(_nunchaku_version), operation, version)
|
||||
|
||||
|
||||
def is_xformers_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current xformers version to a given reference with an operation.
|
||||
|
||||
@@ -197,7 +197,3 @@ def device_synchronize(device_type: Optional[str] = None):
|
||||
device_type = get_device()
|
||||
device_mod = getattr(torch, device_type, torch.cuda)
|
||||
device_mod.synchronize()
|
||||
|
||||
|
||||
def is_fp8_available():
|
||||
return getattr(torch, "float8_e4m3fn", None) is None
|
||||
|
||||
+2
-2
@@ -67,7 +67,7 @@ class SDXLModularTests:
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
||||
pipeline.load_default_components(torch_dtype=torch_dtype)
|
||||
pipeline.load_components(torch_dtype=torch_dtype)
|
||||
return pipeline
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
@@ -158,7 +158,7 @@ class SDXLModularIPAdapterTests:
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
pipe = blocks.init_pipeline(self.repo)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
|
||||
|
||||
@@ -343,7 +343,7 @@ class ModularPipelineTesterMixin:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
@@ -28,10 +28,10 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.pipelines.bria import BriaPipeline
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -149,7 +149,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
assert (output_height, output_width) == (expected_height, expected_width)
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_accelerator
|
||||
@require_torch_accelerator
|
||||
def test_save_load_float16(self, expected_max_diff=1e-2):
|
||||
components = self.get_dummy_components()
|
||||
for name, module in components.items():
|
||||
@@ -237,7 +237,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
class BriaPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = BriaPipeline
|
||||
repo_id = "briaai/BRIA-3.2"
|
||||
@@ -245,12 +245,12 @@ class BriaPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_inputs(self, device, seed=0):
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user