Compare commits

..

22 Commits

Author SHA1 Message Date
sayakpaul 1c91475008 up 2025-11-11 17:54:01 +05:30
sayakpaul 6375c02130 resolve conflicts., 2025-11-11 17:52:53 +05:30
Dhruv Nair 66e6a0215f [CI] Remove unittest dependency from testing_utils.py (#12621)
* update

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-11 16:40:39 +05:30
Sayak Paul e0b1383868 Merge branch 'main' into custom-modular-tests 2025-11-11 09:39:22 +05:30
Cesaryuan 5a47442f92 Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples (#12544)
* Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-10 13:57:52 -08:00
Dhruv Nair 8f6328c4a4 [Modular] Clean up docs (#12604)
update

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-10 23:37:29 +05:30
Dhruv Nair 8d45f219d0 Fix Context Parallel validation checks (#12446)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-10 23:37:07 +05:30
Sayak Paul 54ddce87fd Merge branch 'main' into custom-modular-tests 2025-11-10 09:56:58 +05:30
Sayak Paul c0ce538afc Apply suggestions from code review 2025-11-03 08:31:06 +05:30
Sayak Paul fd88f3d3fc Merge branch 'main' into custom-modular-tests 2025-11-03 08:28:52 +05:30
Sayak Paul ea4f29f0e8 Merge branch 'main' into custom-modular-tests 2025-10-31 15:53:03 +05:30
sayakpaul b8809f76d5 up 2025-10-31 15:52:19 +05:30
Sayak Paul 728655ca01 Merge branch 'main' into custom-modular-tests 2025-10-30 08:47:18 +05:30
sayakpaul 9f113f8138 up 2025-10-29 21:25:21 +05:30
sayakpaul b5f13d9b59 up 2025-10-29 18:28:06 +05:30
sayakpaul ddb5ba734d up 2025-10-29 18:27:31 +05:30
sayakpaul 5f1afc11ac up 2025-10-29 18:19:07 +05:30
sayakpaul ecdd843044 up 2025-10-29 17:10:10 +05:30
sayakpaul 316b71ff2b style. 2025-10-29 17:03:34 +05:30
sayakpaul 1be88f036f up 2025-10-29 17:03:02 +05:30
sayakpaul 77e50155e6 simplify modular workflow ci. 2025-10-29 16:43:39 +05:30
sayakpaul 760a9149a7 start custom block testing. 2025-10-29 16:40:53 +05:30
36 changed files with 504 additions and 231 deletions
+30 -46
View File
@@ -77,62 +77,46 @@ jobs:
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines
name: ${{ matrix.config.name }}
name: Fast PyTorch Modular Pipeline CPU tests
runs-on:
group: ${{ matrix.config.runner }}
group: aws-highmemory-32-plus
container:
image: ${{ matrix.config.image }}
image: diffusers/diffusers-pytorch-cpu
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |
python utils/print_env.py
- name: Environment
run: |
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports
- name: Run fast PyTorch Pipeline CPU tests
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_torch_cpu_modular_pipelines \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
path: reports
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# LoopSequentialPipelineBlocks
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
@@ -21,7 +21,6 @@ This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBl
[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
- `__call__` method defines the loop structure and iteration logic.
@@ -90,4 +89,4 @@ Add more loop blocks to run within each iteration with [`~modular_pipelines.Loop
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
```
```
@@ -37,17 +37,7 @@ A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermedi
]
```
- `intermediate_inputs` are values typically created from a previous block but it can also be directly provided if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
Use `InputParam` to define `intermediate_inputs`.
```py
user_intermediate_inputs = [
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
]
```
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or available as the final output from running the pipeline.
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
Use `OutputParam` to define `intermediate_outputs`.
@@ -65,8 +55,8 @@ The intermediate inputs and outputs share data to connect blocks. They are acces
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
2. Implement the computation logic on the `inputs`.
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
4. Return the components and state which becomes available to the next block.
@@ -76,7 +66,7 @@ def __call__(self, components, state):
block_state = self.get_block_state(state)
# Your computation logic here
# block_state contains all your inputs and intermediate_inputs
# block_state contains all your inputs
# Access them like: block_state.image, block_state.processed_image
# Update the pipeline state with your updated block_states
@@ -112,4 +102,4 @@ def __call__(self, components, state):
unet = components.unet
vae = components.vae
scheduler = components.scheduler
```
```
@@ -183,7 +183,7 @@ from diffusers.modular_pipelines import ComponentsManager
components = ComponentManager()
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
dd_pipeline.load_componenets(torch_dtype=torch.float16)
dd_pipeline.to("cuda")
```
@@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License.
# SequentialPipelineBlocks
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `intermediate_inputs`.
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
<hfoptions id="sequential">
<hfoption id="InputBlock">
@@ -110,4 +110,4 @@ Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by cal
```py
print(blocks)
print(blocks.doc)
```
```
+1 -1
View File
@@ -45,7 +45,7 @@ def check_size(image, height, width):
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
inner_image = inner_image.convert("RGBA")
image = image.convert("RGB")
+11 -6
View File
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
down_block_types: Tuple[str, ...],
up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
@@ -438,16 +438,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -102,7 +102,7 @@ def get_block(
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_mutliscales: Tuple[int] = (),
qkv_mutliscales: Tuple[int, ...] = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
@@ -206,8 +206,8 @@ class Encoder(nn.Module):
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True,
@@ -292,8 +292,8 @@ class Decoder(nn.Module):
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
norm_type: Union[str, Tuple[str]] = "rms_norm",
act_fn: Union[str, Tuple[str]] = "silu",
@@ -440,8 +440,8 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
upsample_block_type: str = "pixel_shuffle",
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels: Tuple[int] = (128, 256, 256, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
latent_channels: int = 16,
layers_per_block: int = 3,
act_fn: str = "silu",
@@ -653,7 +653,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels: Tuple[int] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
act_fn: str = "silu",
norm_num_groups: int = 32,
@@ -601,7 +601,7 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 32,
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
@@ -688,8 +688,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
in_channels: int = 15,
out_channels: int = 3,
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
latent_channels: int = 12,
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
act_fn: str = "silu",
@@ -679,7 +679,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self,
base_dim: int = 96,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
self,
in_channels: int = 4,
out_channels: int = 3,
block_out_channels: Tuple[int] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
):
super().__init__()
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1,
latent_channels: int = 4,
sample_size: int = 32,
@@ -971,7 +971,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
@@ -293,14 +293,14 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
self,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
time_embedding_mix: float = 1.0,
learn_time_embedding: bool = False,
num_attention_heads: Union[int, Tuple[int]] = 4,
block_out_channels: Tuple[int] = (4, 8, 16, 16),
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
base_block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
cross_attention_dim: int = 1024,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
@@ -436,7 +436,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
time_embedding_mix: int = 1.0,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
):
r"""
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
@@ -529,14 +529,19 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
self,
# unet configs
sample_size: Optional[int] = 96,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
norm_num_groups: Optional[int] = 32,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -550,10 +555,10 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
# additional controlnet configs
time_embedding_mix: float = 1.0,
ctrl_conditioning_channels: int = 3,
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
ctrl_conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
ctrl_conditioning_channel_order: str = "rgb",
ctrl_learn_time_embedding: bool = False,
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
ctrl_block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
ctrl_max_norm_num_groups: int = 32,
):
@@ -914,7 +914,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
image_condition_type: Optional[str] = None,
) -> None:
super().__init__()
@@ -139,7 +139,7 @@ class HunyuanVideoFramepackTransformer3DModel(
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
image_condition_type: Optional[str] = None,
has_image_proj: int = False,
image_proj_dim: int = 1152,
@@ -689,7 +689,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
text_embed_dim: int = 3584,
text_embed_2_dim: Optional[int] = None,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (64, 64),
rope_axes_dim: Tuple[int, ...] = (64, 64),
use_meanflow: bool = False,
) -> None:
super().__init__()
@@ -570,7 +570,7 @@ class SkyReelsV2Transformer3DModel(
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 16,
attention_head_dim: int = 128,
in_channels: int = 16,
@@ -563,7 +563,7 @@ class WanTransformer3DModel(
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
@@ -182,7 +182,7 @@ class WanVACETransformer3DModel(
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
+4 -4
View File
@@ -86,11 +86,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D",
down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: str = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
block_out_channels: Tuple[int, ...] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
@@ -177,16 +177,21 @@ class UNet2DConditionModel(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -486,10 +491,10 @@ class UNet2DConditionModel(
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
down_block_types: Tuple[str, ...],
up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
groups: int = 32,
attention_head_dim: int = 64,
layers_per_block: Union[int, Tuple[int]] = 3,
block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
cross_attention_dim: Union[int, Tuple[int]] = 4096,
encoder_hid_dim: int = 4096,
):
@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
sample_size: Optional[int] = None,
in_channels: int = 8,
out_channels: int = 4,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768,
layers_per_block: Union[int, Tuple[int]] = 2,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
num_frames: int = 25,
):
super().__init__()
@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
timestep_ratio_embedding_dim: int = 64,
patch_size: int = 1,
conditioning_dim: int = 2048,
block_out_channels: Tuple[int] = (2048, 2048),
num_attention_heads: Tuple[int] = (32, 32),
down_num_layers_per_block: Tuple[int] = (8, 24),
up_num_layers_per_block: Tuple[int] = (24, 8),
block_out_channels: Tuple[int, ...] = (2048, 2048),
num_attention_heads: Tuple[int, ...] = (32, 32),
down_num_layers_per_block: Tuple[int, ...] = (8, 24),
up_num_layers_per_block: Tuple[int, ...] = (24, 8),
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
1,
1,
@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
kernel_size=3,
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
self_attn: Union[bool, Tuple[bool]] = True,
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
switch_level: Optional[Tuple[bool]] = None,
):
"""
+7 -7
View File
@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
block_out_channels: Tuple[int] = (64,)
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
block_out_channels: int = (64,)
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int] = (64,)
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 1
act_fn: str = "silu"
latent_channels: int = 4
@@ -245,16 +245,21 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
out_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -374,21 +374,21 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"DownBlockFlat",
),
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"UpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
+4 -4
View File
@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
def __init__(
self,
*,
param_names: Tuple[str] = (
param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin):
def __init__(
self,
*,
param_names: Tuple[str] = (
param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
"nerstf.mlp.3.weight",
),
param_shapes: Tuple[Tuple[int]] = (
param_shapes: Tuple[Tuple[int, int], ...] = (
(256, 93),
(256, 256),
(256, 256),
@@ -804,7 +804,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
n_hidden_layers: int = 6,
act_fn: str = "swish",
insert_direction_at: int = 4,
background: Tuple[float] = (
background: Tuple[float, ...] = (
255.0,
255.0,
255.0,
+2
View File
@@ -32,6 +32,8 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
config.addinivalue_line("markers", "slow: mark test as slow")
config.addinivalue_line("markers", "nightly: mark test as nightly")
def pytest_addoption(parser):
@@ -0,0 +1,272 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
from collections import deque
from typing import List
import numpy as np
import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
ComponentSpec,
InputParam,
ModularPipelineBlocks,
OutputParam,
PipelineState,
WanModularPipeline,
)
from ..testing_utils import nightly, require_torch, slow
class DummyCustomBlockSimple(ModularPipelineBlocks):
def __init__(self, use_dummy_model_component=False):
self.use_dummy_model_component = use_dummy_model_component
super().__init__()
@property
def expected_components(self):
if self.use_dummy_model_component:
return [ComponentSpec("transformer", FluxTransformer2DModel)]
else:
return []
@property
def inputs(self) -> List[InputParam]:
return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]
@property
def intermediate_inputs(self) -> List[InputParam]:
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"output_prompt",
type_hint=str,
description="Modified prompt",
)
]
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
old_prompt = block_state.prompt
block_state.output_prompt = "Modular diffusers + " + old_prompt
self.set_block_state(state, block_state)
return components, state
CODE_STR = """
from diffusers.modular_pipelines import (
ComponentSpec,
InputParam,
ModularPipelineBlocks,
OutputParam,
PipelineState,
WanModularPipeline,
)
from typing import List
class DummyCustomBlockSimple(ModularPipelineBlocks):
def __init__(self, use_dummy_model_component=False):
self.use_dummy_model_component = use_dummy_model_component
super().__init__()
@property
def expected_components(self):
if self.use_dummy_model_component:
return [ComponentSpec("transformer", FluxTransformer2DModel)]
else:
return []
@property
def inputs(self) -> List[InputParam]:
return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]
@property
def intermediate_inputs(self) -> List[InputParam]:
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"output_prompt",
type_hint=str,
description="Modified prompt",
)
]
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
old_prompt = block_state.prompt
block_state.output_prompt = "Modular diffusers + " + old_prompt
self.set_block_state(state, block_state)
return components, state
"""
class TestModularCustomBlocks:
def _test_block_properties(self, block):
assert not block.expected_components
assert not block.intermediate_inputs
actual_inputs = [inp.name for inp in block.inputs]
actual_intermediate_outputs = [out.name for out in block.intermediate_outputs]
assert actual_inputs == ["prompt"]
assert actual_intermediate_outputs == ["output_prompt"]
def test_custom_block_properties(self):
custom_block = DummyCustomBlockSimple()
self._test_block_properties(custom_block)
def test_custom_block_output(self):
custom_block = DummyCustomBlockSimple()
pipe = custom_block.init_pipeline()
prompt = "Diffusers is nice"
output = pipe(prompt=prompt)
actual_inputs = [inp.name for inp in custom_block.inputs]
actual_intermediate_outputs = [out.name for out in custom_block.intermediate_outputs]
assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs)
output_prompt = output.values["output_prompt"]
assert output_prompt.startswith("Modular diffusers + ")
def test_custom_block_saving_loading(self):
custom_block = DummyCustomBlockSimple()
with tempfile.TemporaryDirectory() as tmpdir:
custom_block.save_pretrained(tmpdir)
assert any("modular_config.json" in k for k in os.listdir(tmpdir))
with open(os.path.join(tmpdir, "modular_config.json"), "r") as f:
config = json.load(f)
auto_map = config["auto_map"]
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
# This is why, we have to separately save the Python script here.
code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py")
with open(code_path, "w") as f:
f.write(CODE_STR)
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True)
pipe = loaded_custom_block.init_pipeline()
prompt = "Diffusers is nice"
output = pipe(prompt=prompt)
actual_inputs = [inp.name for inp in loaded_custom_block.inputs]
actual_intermediate_outputs = [out.name for out in loaded_custom_block.intermediate_outputs]
assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs)
output_prompt = output.values["output_prompt"]
assert output_prompt.startswith("Modular diffusers + ")
def test_custom_block_supported_components(self):
custom_block = DummyCustomBlockSimple(use_dummy_model_component=True)
pipe = custom_block.init_pipeline("hf-internal-testing/tiny-flux-kontext-pipe")
pipe.load_components()
assert len(pipe.components) == 1
assert pipe.component_names[0] == "transformer"
def test_custom_block_loads_from_hub(self):
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
self._test_block_properties(block)
pipe = block.init_pipeline()
prompt = "Diffusers is nice"
output = pipe(prompt=prompt)
output_prompt = output.values["output_prompt"]
assert output_prompt.startswith("Modular diffusers + ")
@slow
@nightly
@require_torch
class TestKreaCustomBlocksIntegration:
repo_id = "krea/krea-realtime-video"
def test_loading_from_hub(self):
blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
block_names = sorted(blocks.sub_blocks)
assert block_names == sorted(["text_encoder", "before_denoise", "denoise", "decode"])
pipe = WanModularPipeline(blocks, self.repo_id)
pipe.load_components(
trust_remote_code=True,
device_map="cuda",
torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
)
assert len(pipe.components) == 7
assert sorted(pipe.components) == sorted(
["text_encoder", "tokenizer", "guider", "scheduler", "vae", "transformer", "video_processor"]
)
def test_forward(self):
blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
pipe = WanModularPipeline(blocks, self.repo_id)
pipe.load_components(
trust_remote_code=True,
device_map="cuda",
torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
)
num_frames_per_block = 2
num_blocks = 2
state = PipelineState()
state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len))
prompt = ["a cat sitting on a boat"]
for block in pipe.transformer.blocks:
block.self_attn.fuse_projections()
for block_idx in range(num_blocks):
state = pipe(
state,
prompt=prompt,
num_inference_steps=2,
num_blocks=num_blocks,
num_frames_per_block=num_frames_per_block,
block_idx=block_idx,
generator=torch.manual_seed(42),
)
current_frames = np.array(state.values["videos"][0])
current_frames_flat = current_frames.flatten()
actual_slices = np.concatenate([current_frames_flat[:4], current_frames_flat[-4:]]).tolist()
if block_idx == 0:
assert current_frames.shape == (5, 480, 832, 3)
expected_slices = np.array([211, 229, 238, 208, 195, 180, 188, 193])
else:
assert current_frames.shape == (8, 480, 832, 3)
expected_slices = np.array([179, 203, 214, 176, 194, 181, 187, 191])
assert np.allclose(actual_slices, expected_slices)
+79 -78
View File
@@ -13,7 +13,6 @@ import struct
import sys
import tempfile
import time
import unittest
import urllib.parse
from collections import UserDict
from contextlib import contextmanager
@@ -24,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tupl
import numpy as np
import PIL.Image
import PIL.ImageOps
import pytest
import requests
from numpy.linalg import norm
from packaging import version
@@ -267,7 +267,7 @@ def slow(test_case):
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)
def nightly(test_case):
@@ -277,7 +277,7 @@ def nightly(test_case):
Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case)
def is_torch_compile(test_case):
@@ -287,23 +287,23 @@ def is_torch_compile(test_case):
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)
return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
"""
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
return pytest.mark.skipif(not is_torch_available(), reason="test requires PyTorch")(test_case)
def require_torch_2(test_case):
"""
Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
"""
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
test_case
)
return pytest.mark.skipif(
not (is_torch_available() and is_torch_version(">=", "2.0.0")), reason="test requires PyTorch 2"
)(test_case)
def require_torch_version_greater_equal(torch_version):
@@ -311,8 +311,9 @@ def require_torch_version_greater_equal(torch_version):
def decorator(test_case):
correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
return unittest.skipUnless(
correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
return pytest.mark.skipif(
not correct_torch_version,
reason=f"test requires torch with the version greater than or equal to {torch_version}",
)(test_case)
return decorator
@@ -323,8 +324,8 @@ def require_torch_version_greater(torch_version):
def decorator(test_case):
correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
return unittest.skipUnless(
correct_torch_version, f"test requires torch with the version greater than {torch_version}"
return pytest.mark.skipif(
not correct_torch_version, reason=f"test requires torch with the version greater than {torch_version}"
)(test_case)
return decorator
@@ -332,19 +333,18 @@ def require_torch_version_greater(torch_version):
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
test_case
)
return pytest.mark.skipif(torch_device != "cuda", reason="test requires PyTorch+CUDA")(test_case)
def require_torch_cuda_compatibility(expected_compute_capability):
def decorator(test_case):
if torch.cuda.is_available():
current_compute_capability = get_torch_cuda_device_capability()
return unittest.skipUnless(
float(current_compute_capability) == float(expected_compute_capability),
"Test not supported for this compute capability.",
)
return pytest.mark.skipif(
float(current_compute_capability) != float(expected_compute_capability),
reason="Test not supported for this compute capability.",
)(test_case)
return test_case
return decorator
@@ -352,9 +352,7 @@ def require_torch_cuda_compatibility(expected_compute_capability):
# These decorators are for accelerator-specific behaviours that are not GPU-specific
def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
test_case
)
return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case)
def require_torch_multi_gpu(test_case):
@@ -364,11 +362,11 @@ def require_torch_multi_gpu(test_case):
-k "multi_gpu"
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
return pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="test requires multiple GPUs")(test_case)
def require_torch_multi_accelerator(test_case):
@@ -377,27 +375,28 @@ def require_torch_multi_accelerator(test_case):
without multiple hardware accelerators.
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch
return unittest.skipUnless(
torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
return pytest.mark.skipif(
not (torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1),
reason="test requires multiple hardware accelerators",
)(test_case)
def require_torch_accelerator_with_fp16(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
test_case
)
return pytest.mark.skipif(
not _is_torch_fp16_available(torch_device), reason="test requires accelerator with fp16 support"
)(test_case)
def require_torch_accelerator_with_fp64(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP64 data type."""
return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
test_case
)
return pytest.mark.skipif(
not _is_torch_fp64_available(torch_device), reason="test requires accelerator with fp64 support"
)(test_case)
def require_big_gpu_with_torch_cuda(test_case):
@@ -406,17 +405,17 @@ def require_big_gpu_with_torch_cuda(test_case):
etc.
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch
if not torch.cuda.is_available():
return unittest.skip("test requires PyTorch CUDA")(test_case)
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
device_properties = torch.cuda.get_device_properties(0)
total_memory = device_properties.total_memory / (1024**3)
return unittest.skipUnless(
total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
return pytest.mark.skipif(
total_memory < BIG_GPU_MEMORY, reason=f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
)(test_case)
@@ -430,12 +429,12 @@ def require_big_accelerator(test_case):
test_case = pytest.mark.big_accelerator(test_case)
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch
if not (torch.cuda.is_available() or torch.xpu.is_available()):
return unittest.skip("test requires PyTorch CUDA")(test_case)
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
if torch.xpu.is_available():
device_properties = torch.xpu.get_device_properties(0)
@@ -443,30 +442,30 @@ def require_big_accelerator(test_case):
device_properties = torch.cuda.get_device_properties(0)
total_memory = device_properties.total_memory / (1024**3)
return unittest.skipUnless(
total_memory >= BIG_GPU_MEMORY,
f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
return pytest.mark.skipif(
total_memory < BIG_GPU_MEMORY,
reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
)(test_case)
def require_torch_accelerator_with_training(test_case):
"""Decorator marking a test that requires an accelerator with support for training."""
return unittest.skipUnless(
is_torch_available() and backend_supports_training(torch_device),
"test requires accelerator with training support",
return pytest.mark.skipif(
not (is_torch_available() and backend_supports_training(torch_device)),
reason="test requires accelerator with training support",
)(test_case)
def skip_mps(test_case):
"""Decorator marking a test to skip if torch_device is 'mps'"""
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
return pytest.mark.skipif(torch_device == "mps", reason="test requires non 'mps' device")(test_case)
def require_flax(test_case):
"""
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
"""
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
return pytest.mark.skipif(not is_flax_available(), reason="test requires JAX & Flax")(test_case)
def require_compel(test_case):
@@ -474,21 +473,21 @@ def require_compel(test_case):
Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
the library is not installed.
"""
return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
return pytest.mark.skipif(not is_compel_available(), reason="test requires compel")(test_case)
def require_onnxruntime(test_case):
"""
Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
"""
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
return pytest.mark.skipif(not is_onnx_available(), reason="test requires onnxruntime")(test_case)
def require_note_seq(test_case):
"""
Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
"""
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
return pytest.mark.skipif(not is_note_seq_available(), reason="test requires note_seq")(test_case)
def require_accelerator(test_case):
@@ -496,14 +495,14 @@ def require_accelerator(test_case):
Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
hardware accelerator available.
"""
return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
return pytest.mark.skipif(torch_device == "cpu", reason="test requires a hardware accelerator")(test_case)
def require_torchsde(test_case):
"""
Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
"""
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
return pytest.mark.skipif(not is_torchsde_available(), reason="test requires torchsde")(test_case)
def require_peft_backend(test_case):
@@ -511,35 +510,35 @@ def require_peft_backend(test_case):
Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
transformers.
"""
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case)
def require_timm(test_case):
"""
Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
"""
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
return pytest.mark.skipif(not is_timm_available(), reason="test requires timm")(test_case)
def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
"""
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
return pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")(test_case)
def require_quanto(test_case):
"""
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
"""
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
return pytest.mark.skipif(not is_optimum_quanto_available(), reason="test requires quanto")(test_case)
def require_accelerate(test_case):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
"""
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
def require_peft_version_greater(peft_version):
@@ -552,8 +551,8 @@ def require_peft_version_greater(peft_version):
correct_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse(peft_version)
return unittest.skipUnless(
correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
return pytest.mark.skipif(
not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}"
)(test_case)
return decorator
@@ -569,9 +568,9 @@ def require_transformers_version_greater(transformers_version):
correct_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse(transformers_version)
return unittest.skipUnless(
correct_transformers_version,
f"test requires transformers with the version greater than {transformers_version}",
return pytest.mark.skipif(
not correct_transformers_version,
reason=f"test requires transformers with the version greater than {transformers_version}",
)(test_case)
return decorator
@@ -582,8 +581,9 @@ def require_accelerate_version_greater(accelerate_version):
correct_accelerate_version = is_accelerate_available() and version.parse(
version.parse(importlib.metadata.version("accelerate")).base_version
) > version.parse(accelerate_version)
return unittest.skipUnless(
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
return pytest.mark.skipif(
not correct_accelerate_version,
reason=f"Test requires accelerate with the version greater than {accelerate_version}.",
)(test_case)
return decorator
@@ -594,8 +594,8 @@ def require_bitsandbytes_version_greater(bnb_version):
correct_bnb_version = is_bitsandbytes_available() and version.parse(
version.parse(importlib.metadata.version("bitsandbytes")).base_version
) > version.parse(bnb_version)
return unittest.skipUnless(
correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
return pytest.mark.skipif(
not correct_bnb_version, reason=f"Test requires bitsandbytes with the version greater than {bnb_version}."
)(test_case)
return decorator
@@ -606,8 +606,9 @@ def require_hf_hub_version_greater(hf_hub_version):
correct_hf_hub_version = version.parse(
version.parse(importlib.metadata.version("huggingface_hub")).base_version
) > version.parse(hf_hub_version)
return unittest.skipUnless(
correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
return pytest.mark.skipif(
not correct_hf_hub_version,
reason=f"Test requires huggingface_hub with the version greater than {hf_hub_version}.",
)(test_case)
return decorator
@@ -618,8 +619,8 @@ def require_gguf_version_greater_or_equal(gguf_version):
correct_gguf_version = is_gguf_available() and version.parse(
version.parse(importlib.metadata.version("gguf")).base_version
) >= version.parse(gguf_version)
return unittest.skipUnless(
correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
return pytest.mark.skipif(
not correct_gguf_version, reason=f"Test requires gguf with the version greater than {gguf_version}."
)(test_case)
return decorator
@@ -630,8 +631,8 @@ def require_torchao_version_greater_or_equal(torchao_version):
correct_torchao_version = is_torchao_available() and version.parse(
version.parse(importlib.metadata.version("torchao")).base_version
) >= version.parse(torchao_version)
return unittest.skipUnless(
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
return pytest.mark.skipif(
not correct_torchao_version, reason=f"Test requires torchao with version greater than {torchao_version}."
)(test_case)
return decorator
@@ -642,8 +643,8 @@ def require_kernels_version_greater_or_equal(kernels_version):
correct_kernels_version = is_kernels_available() and version.parse(
version.parse(importlib.metadata.version("kernels")).base_version
) >= version.parse(kernels_version)
return unittest.skipUnless(
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
return pytest.mark.skipif(
not correct_kernels_version, reason=f"Test requires kernels with version greater than {kernels_version}."
)(test_case)
return decorator
@@ -653,7 +654,7 @@ def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
"""
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
return pytest.mark.skipif(USE_PEFT_BACKEND, reason="test skipped in favor of PEFT backend")(test_case)
def get_python_version():
@@ -1064,8 +1065,8 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
Args:
test_case (`unittest.TestCase`):
The test that will run `target_func`.
test_case:
The test case object that will run `target_func`.
target_func (`Callable`):
The function implementing the actual testing logic.
inputs (`dict`, *optional*, defaults to `None`):
@@ -1083,7 +1084,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
input_queue = ctx.Queue(1)
output_queue = ctx.JoinableQueue(1)
# We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
# We can't send test case objects to the child, otherwise we get issues regarding pickle.
input_queue.put(inputs, timeout=timeout)
process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))