Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1c91475008 | |||
| 6375c02130 | |||
| 66e6a0215f | |||
| e0b1383868 | |||
| 54ddce87fd | |||
| c0ce538afc | |||
| fd88f3d3fc | |||
| ea4f29f0e8 | |||
| b8809f76d5 | |||
| 728655ca01 | |||
| 9f113f8138 | |||
| b5f13d9b59 | |||
| ddb5ba734d | |||
| 5f1afc11ac | |||
| ecdd843044 | |||
| 316b71ff2b | |||
| 1be88f036f | |||
| 77e50155e6 | |||
| 760a9149a7 |
@@ -77,62 +77,46 @@ jobs:
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch Modular Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: aws-highmemory-32-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_modular_pipelines
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
name: Fast PyTorch Modular Pipeline CPU tests
|
||||
runs-on:
|
||||
group: ${{ matrix.config.runner }}
|
||||
|
||||
group: aws-highmemory-32-plus
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/modular_pipelines
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
run: |
|
||||
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_torch_cpu_modular_pipelines \
|
||||
tests/modular_pipelines
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
|
||||
path: reports
|
||||
|
||||
@@ -32,6 +32,8 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
|
||||
config.addinivalue_line("markers", "slow: mark test as slow")
|
||||
config.addinivalue_line("markers", "nightly: mark test as nightly")
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import deque
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.modular_pipelines import (
|
||||
ComponentSpec,
|
||||
InputParam,
|
||||
ModularPipelineBlocks,
|
||||
OutputParam,
|
||||
PipelineState,
|
||||
WanModularPipeline,
|
||||
)
|
||||
|
||||
from ..testing_utils import nightly, require_torch, slow
|
||||
|
||||
|
||||
class DummyCustomBlockSimple(ModularPipelineBlocks):
|
||||
def __init__(self, use_dummy_model_component=False):
|
||||
self.use_dummy_model_component = use_dummy_model_component
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
if self.use_dummy_model_component:
|
||||
return [ComponentSpec("transformer", FluxTransformer2DModel)]
|
||||
else:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"output_prompt",
|
||||
type_hint=str,
|
||||
description="Modified prompt",
|
||||
)
|
||||
]
|
||||
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
old_prompt = block_state.prompt
|
||||
block_state.output_prompt = "Modular diffusers + " + old_prompt
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
CODE_STR = """
|
||||
from diffusers.modular_pipelines import (
|
||||
ComponentSpec,
|
||||
InputParam,
|
||||
ModularPipelineBlocks,
|
||||
OutputParam,
|
||||
PipelineState,
|
||||
WanModularPipeline,
|
||||
)
|
||||
from typing import List
|
||||
|
||||
class DummyCustomBlockSimple(ModularPipelineBlocks):
|
||||
def __init__(self, use_dummy_model_component=False):
|
||||
self.use_dummy_model_component = use_dummy_model_component
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
if self.use_dummy_model_component:
|
||||
return [ComponentSpec("transformer", FluxTransformer2DModel)]
|
||||
else:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"output_prompt",
|
||||
type_hint=str,
|
||||
description="Modified prompt",
|
||||
)
|
||||
]
|
||||
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
old_prompt = block_state.prompt
|
||||
block_state.output_prompt = "Modular diffusers + " + old_prompt
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
"""
|
||||
|
||||
|
||||
class TestModularCustomBlocks:
|
||||
def _test_block_properties(self, block):
|
||||
assert not block.expected_components
|
||||
assert not block.intermediate_inputs
|
||||
|
||||
actual_inputs = [inp.name for inp in block.inputs]
|
||||
actual_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
assert actual_inputs == ["prompt"]
|
||||
assert actual_intermediate_outputs == ["output_prompt"]
|
||||
|
||||
def test_custom_block_properties(self):
|
||||
custom_block = DummyCustomBlockSimple()
|
||||
self._test_block_properties(custom_block)
|
||||
|
||||
def test_custom_block_output(self):
|
||||
custom_block = DummyCustomBlockSimple()
|
||||
pipe = custom_block.init_pipeline()
|
||||
prompt = "Diffusers is nice"
|
||||
output = pipe(prompt=prompt)
|
||||
|
||||
actual_inputs = [inp.name for inp in custom_block.inputs]
|
||||
actual_intermediate_outputs = [out.name for out in custom_block.intermediate_outputs]
|
||||
assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs)
|
||||
|
||||
output_prompt = output.values["output_prompt"]
|
||||
assert output_prompt.startswith("Modular diffusers + ")
|
||||
|
||||
def test_custom_block_saving_loading(self):
|
||||
custom_block = DummyCustomBlockSimple()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
custom_block.save_pretrained(tmpdir)
|
||||
assert any("modular_config.json" in k for k in os.listdir(tmpdir))
|
||||
|
||||
with open(os.path.join(tmpdir, "modular_config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
auto_map = config["auto_map"]
|
||||
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
|
||||
|
||||
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
|
||||
# This is why, we have to separately save the Python script here.
|
||||
code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py")
|
||||
with open(code_path, "w") as f:
|
||||
f.write(CODE_STR)
|
||||
|
||||
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True)
|
||||
|
||||
pipe = loaded_custom_block.init_pipeline()
|
||||
prompt = "Diffusers is nice"
|
||||
output = pipe(prompt=prompt)
|
||||
|
||||
actual_inputs = [inp.name for inp in loaded_custom_block.inputs]
|
||||
actual_intermediate_outputs = [out.name for out in loaded_custom_block.intermediate_outputs]
|
||||
assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs)
|
||||
|
||||
output_prompt = output.values["output_prompt"]
|
||||
assert output_prompt.startswith("Modular diffusers + ")
|
||||
|
||||
def test_custom_block_supported_components(self):
|
||||
custom_block = DummyCustomBlockSimple(use_dummy_model_component=True)
|
||||
pipe = custom_block.init_pipeline("hf-internal-testing/tiny-flux-kontext-pipe")
|
||||
pipe.load_components()
|
||||
|
||||
assert len(pipe.components) == 1
|
||||
assert pipe.component_names[0] == "transformer"
|
||||
|
||||
def test_custom_block_loads_from_hub(self):
|
||||
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
|
||||
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
|
||||
self._test_block_properties(block)
|
||||
|
||||
pipe = block.init_pipeline()
|
||||
|
||||
prompt = "Diffusers is nice"
|
||||
output = pipe(prompt=prompt)
|
||||
output_prompt = output.values["output_prompt"]
|
||||
assert output_prompt.startswith("Modular diffusers + ")
|
||||
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch
|
||||
class TestKreaCustomBlocksIntegration:
|
||||
repo_id = "krea/krea-realtime-video"
|
||||
|
||||
def test_loading_from_hub(self):
|
||||
blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
|
||||
block_names = sorted(blocks.sub_blocks)
|
||||
|
||||
assert block_names == sorted(["text_encoder", "before_denoise", "denoise", "decode"])
|
||||
|
||||
pipe = WanModularPipeline(blocks, self.repo_id)
|
||||
pipe.load_components(
|
||||
trust_remote_code=True,
|
||||
device_map="cuda",
|
||||
torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
|
||||
)
|
||||
assert len(pipe.components) == 7
|
||||
assert sorted(pipe.components) == sorted(
|
||||
["text_encoder", "tokenizer", "guider", "scheduler", "vae", "transformer", "video_processor"]
|
||||
)
|
||||
|
||||
def test_forward(self):
|
||||
blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
|
||||
pipe = WanModularPipeline(blocks, self.repo_id)
|
||||
pipe.load_components(
|
||||
trust_remote_code=True,
|
||||
device_map="cuda",
|
||||
torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
|
||||
)
|
||||
|
||||
num_frames_per_block = 2
|
||||
num_blocks = 2
|
||||
|
||||
state = PipelineState()
|
||||
state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len))
|
||||
|
||||
prompt = ["a cat sitting on a boat"]
|
||||
|
||||
for block in pipe.transformer.blocks:
|
||||
block.self_attn.fuse_projections()
|
||||
|
||||
for block_idx in range(num_blocks):
|
||||
state = pipe(
|
||||
state,
|
||||
prompt=prompt,
|
||||
num_inference_steps=2,
|
||||
num_blocks=num_blocks,
|
||||
num_frames_per_block=num_frames_per_block,
|
||||
block_idx=block_idx,
|
||||
generator=torch.manual_seed(42),
|
||||
)
|
||||
current_frames = np.array(state.values["videos"][0])
|
||||
current_frames_flat = current_frames.flatten()
|
||||
actual_slices = np.concatenate([current_frames_flat[:4], current_frames_flat[-4:]]).tolist()
|
||||
|
||||
if block_idx == 0:
|
||||
assert current_frames.shape == (5, 480, 832, 3)
|
||||
expected_slices = np.array([211, 229, 238, 208, 195, 180, 188, 193])
|
||||
else:
|
||||
assert current_frames.shape == (8, 480, 832, 3)
|
||||
expected_slices = np.array([179, 203, 214, 176, 194, 181, 187, 191])
|
||||
|
||||
assert np.allclose(actual_slices, expected_slices)
|
||||
+79
-78
@@ -13,7 +13,6 @@ import struct
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
import urllib.parse
|
||||
from collections import UserDict
|
||||
from contextlib import contextmanager
|
||||
@@ -24,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tupl
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
import pytest
|
||||
import requests
|
||||
from numpy.linalg import norm
|
||||
from packaging import version
|
||||
@@ -267,7 +267,7 @@ def slow(test_case):
|
||||
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||
return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)
|
||||
|
||||
|
||||
def nightly(test_case):
|
||||
@@ -277,7 +277,7 @@ def nightly(test_case):
|
||||
Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
|
||||
return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case)
|
||||
|
||||
|
||||
def is_torch_compile(test_case):
|
||||
@@ -287,23 +287,23 @@ def is_torch_compile(test_case):
|
||||
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)
|
||||
return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
|
||||
return pytest.mark.skipif(not is_torch_available(), reason="test requires PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_torch_2(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(
|
||||
not (is_torch_available() and is_torch_version(">=", "2.0.0")), reason="test requires PyTorch 2"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_version_greater_equal(torch_version):
|
||||
@@ -311,8 +311,9 @@ def require_torch_version_greater_equal(torch_version):
|
||||
|
||||
def decorator(test_case):
|
||||
correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
|
||||
return pytest.mark.skipif(
|
||||
not correct_torch_version,
|
||||
reason=f"test requires torch with the version greater than or equal to {torch_version}",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -323,8 +324,8 @@ def require_torch_version_greater(torch_version):
|
||||
|
||||
def decorator(test_case):
|
||||
correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torch_version, f"test requires torch with the version greater than {torch_version}"
|
||||
return pytest.mark.skipif(
|
||||
not correct_torch_version, reason=f"test requires torch with the version greater than {torch_version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -332,19 +333,18 @@ def require_torch_version_greater(torch_version):
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(torch_device != "cuda", reason="test requires PyTorch+CUDA")(test_case)
|
||||
|
||||
|
||||
def require_torch_cuda_compatibility(expected_compute_capability):
|
||||
def decorator(test_case):
|
||||
if torch.cuda.is_available():
|
||||
current_compute_capability = get_torch_cuda_device_capability()
|
||||
return unittest.skipUnless(
|
||||
float(current_compute_capability) == float(expected_compute_capability),
|
||||
"Test not supported for this compute capability.",
|
||||
)
|
||||
return pytest.mark.skipif(
|
||||
float(current_compute_capability) != float(expected_compute_capability),
|
||||
reason="Test not supported for this compute capability.",
|
||||
)(test_case)
|
||||
return test_case
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -352,9 +352,7 @@ def require_torch_cuda_compatibility(expected_compute_capability):
|
||||
# These decorators are for accelerator-specific behaviours that are not GPU-specific
|
||||
def require_torch_accelerator(test_case):
|
||||
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
|
||||
return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_gpu(test_case):
|
||||
@@ -364,11 +362,11 @@ def require_torch_multi_gpu(test_case):
|
||||
-k "multi_gpu"
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
return pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_accelerator(test_case):
|
||||
@@ -377,27 +375,28 @@ def require_torch_multi_accelerator(test_case):
|
||||
without multiple hardware accelerators.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
return unittest.skipUnless(
|
||||
torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
|
||||
return pytest.mark.skipif(
|
||||
not (torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1),
|
||||
reason="test requires multiple hardware accelerators",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp16(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(
|
||||
not _is_torch_fp16_available(torch_device), reason="test requires accelerator with fp16 support"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp64(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP64 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
|
||||
test_case
|
||||
)
|
||||
return pytest.mark.skipif(
|
||||
not _is_torch_fp64_available(torch_device), reason="test requires accelerator with fp64 support"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_big_gpu_with_torch_cuda(test_case):
|
||||
@@ -406,17 +405,17 @@ def require_big_gpu_with_torch_cuda(test_case):
|
||||
etc.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return unittest.skip("test requires PyTorch CUDA")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(0)
|
||||
total_memory = device_properties.total_memory / (1024**3)
|
||||
return unittest.skipUnless(
|
||||
total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
|
||||
return pytest.mark.skipif(
|
||||
total_memory < BIG_GPU_MEMORY, reason=f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
|
||||
)(test_case)
|
||||
|
||||
|
||||
@@ -430,12 +429,12 @@ def require_big_accelerator(test_case):
|
||||
test_case = pytest.mark.big_accelerator(test_case)
|
||||
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
||||
return unittest.skip("test requires PyTorch CUDA")(test_case)
|
||||
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
|
||||
|
||||
if torch.xpu.is_available():
|
||||
device_properties = torch.xpu.get_device_properties(0)
|
||||
@@ -443,30 +442,30 @@ def require_big_accelerator(test_case):
|
||||
device_properties = torch.cuda.get_device_properties(0)
|
||||
|
||||
total_memory = device_properties.total_memory / (1024**3)
|
||||
return unittest.skipUnless(
|
||||
total_memory >= BIG_GPU_MEMORY,
|
||||
f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
|
||||
return pytest.mark.skipif(
|
||||
total_memory < BIG_GPU_MEMORY,
|
||||
reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_training(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for training."""
|
||||
return unittest.skipUnless(
|
||||
is_torch_available() and backend_supports_training(torch_device),
|
||||
"test requires accelerator with training support",
|
||||
return pytest.mark.skipif(
|
||||
not (is_torch_available() and backend_supports_training(torch_device)),
|
||||
reason="test requires accelerator with training support",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def skip_mps(test_case):
|
||||
"""Decorator marking a test to skip if torch_device is 'mps'"""
|
||||
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
|
||||
return pytest.mark.skipif(torch_device == "mps", reason="test requires non 'mps' device")(test_case)
|
||||
|
||||
|
||||
def require_flax(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
|
||||
"""
|
||||
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
|
||||
return pytest.mark.skipif(not is_flax_available(), reason="test requires JAX & Flax")(test_case)
|
||||
|
||||
|
||||
def require_compel(test_case):
|
||||
@@ -474,21 +473,21 @@ def require_compel(test_case):
|
||||
Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
|
||||
the library is not installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
|
||||
return pytest.mark.skipif(not is_compel_available(), reason="test requires compel")(test_case)
|
||||
|
||||
|
||||
def require_onnxruntime(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
|
||||
return pytest.mark.skipif(not is_onnx_available(), reason="test requires onnxruntime")(test_case)
|
||||
|
||||
|
||||
def require_note_seq(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
|
||||
return pytest.mark.skipif(not is_note_seq_available(), reason="test requires note_seq")(test_case)
|
||||
|
||||
|
||||
def require_accelerator(test_case):
|
||||
@@ -496,14 +495,14 @@ def require_accelerator(test_case):
|
||||
Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
|
||||
hardware accelerator available.
|
||||
"""
|
||||
return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
|
||||
return pytest.mark.skipif(torch_device == "cpu", reason="test requires a hardware accelerator")(test_case)
|
||||
|
||||
|
||||
def require_torchsde(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
|
||||
return pytest.mark.skipif(not is_torchsde_available(), reason="test requires torchsde")(test_case)
|
||||
|
||||
|
||||
def require_peft_backend(test_case):
|
||||
@@ -511,35 +510,35 @@ def require_peft_backend(test_case):
|
||||
Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
|
||||
transformers.
|
||||
"""
|
||||
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
|
||||
return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case)
|
||||
|
||||
|
||||
def require_timm(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
|
||||
return pytest.mark.skipif(not is_timm_available(), reason="test requires timm")(test_case)
|
||||
|
||||
|
||||
def require_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
|
||||
return pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")(test_case)
|
||||
|
||||
|
||||
def require_quanto(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
|
||||
return pytest.mark.skipif(not is_optimum_quanto_available(), reason="test requires quanto")(test_case)
|
||||
|
||||
|
||||
def require_accelerate(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
|
||||
return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
|
||||
|
||||
|
||||
def require_peft_version_greater(peft_version):
|
||||
@@ -552,8 +551,8 @@ def require_peft_version_greater(peft_version):
|
||||
correct_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
) > version.parse(peft_version)
|
||||
return unittest.skipUnless(
|
||||
correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
|
||||
return pytest.mark.skipif(
|
||||
not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -569,9 +568,9 @@ def require_transformers_version_greater(transformers_version):
|
||||
correct_transformers_version = is_transformers_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("transformers")).base_version
|
||||
) > version.parse(transformers_version)
|
||||
return unittest.skipUnless(
|
||||
correct_transformers_version,
|
||||
f"test requires transformers with the version greater than {transformers_version}",
|
||||
return pytest.mark.skipif(
|
||||
not correct_transformers_version,
|
||||
reason=f"test requires transformers with the version greater than {transformers_version}",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -582,8 +581,9 @@ def require_accelerate_version_greater(accelerate_version):
|
||||
correct_accelerate_version = is_accelerate_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("accelerate")).base_version
|
||||
) > version.parse(accelerate_version)
|
||||
return unittest.skipUnless(
|
||||
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_accelerate_version,
|
||||
reason=f"Test requires accelerate with the version greater than {accelerate_version}.",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -594,8 +594,8 @@ def require_bitsandbytes_version_greater(bnb_version):
|
||||
correct_bnb_version = is_bitsandbytes_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("bitsandbytes")).base_version
|
||||
) > version.parse(bnb_version)
|
||||
return unittest.skipUnless(
|
||||
correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_bnb_version, reason=f"Test requires bitsandbytes with the version greater than {bnb_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -606,8 +606,9 @@ def require_hf_hub_version_greater(hf_hub_version):
|
||||
correct_hf_hub_version = version.parse(
|
||||
version.parse(importlib.metadata.version("huggingface_hub")).base_version
|
||||
) > version.parse(hf_hub_version)
|
||||
return unittest.skipUnless(
|
||||
correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_hf_hub_version,
|
||||
reason=f"Test requires huggingface_hub with the version greater than {hf_hub_version}.",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -618,8 +619,8 @@ def require_gguf_version_greater_or_equal(gguf_version):
|
||||
correct_gguf_version = is_gguf_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("gguf")).base_version
|
||||
) >= version.parse(gguf_version)
|
||||
return unittest.skipUnless(
|
||||
correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_gguf_version, reason=f"Test requires gguf with the version greater than {gguf_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -630,8 +631,8 @@ def require_torchao_version_greater_or_equal(torchao_version):
|
||||
correct_torchao_version = is_torchao_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("torchao")).base_version
|
||||
) >= version.parse(torchao_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_torchao_version, reason=f"Test requires torchao with version greater than {torchao_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -642,8 +643,8 @@ def require_kernels_version_greater_or_equal(kernels_version):
|
||||
correct_kernels_version = is_kernels_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("kernels")).base_version
|
||||
) >= version.parse(kernels_version)
|
||||
return unittest.skipUnless(
|
||||
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
|
||||
return pytest.mark.skipif(
|
||||
not correct_kernels_version, reason=f"Test requires kernels with version greater than {kernels_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -653,7 +654,7 @@ def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
"""
|
||||
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
|
||||
return pytest.mark.skipif(USE_PEFT_BACKEND, reason="test skipped in favor of PEFT backend")(test_case)
|
||||
|
||||
|
||||
def get_python_version():
|
||||
@@ -1064,8 +1065,8 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
|
||||
To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
|
||||
|
||||
Args:
|
||||
test_case (`unittest.TestCase`):
|
||||
The test that will run `target_func`.
|
||||
test_case:
|
||||
The test case object that will run `target_func`.
|
||||
target_func (`Callable`):
|
||||
The function implementing the actual testing logic.
|
||||
inputs (`dict`, *optional*, defaults to `None`):
|
||||
@@ -1083,7 +1084,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
|
||||
input_queue = ctx.Queue(1)
|
||||
output_queue = ctx.JoinableQueue(1)
|
||||
|
||||
# We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
|
||||
# We can't send test case objects to the child, otherwise we get issues regarding pickle.
|
||||
input_queue.put(inputs, timeout=timeout)
|
||||
|
||||
process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
|
||||
|
||||
Reference in New Issue
Block a user