Compare commits

...

9 Commits

Author SHA1 Message Date
DN6 bffa3a9754 update 2025-11-14 15:48:19 +05:30
DN6 1c558712e8 Merge branch 'main' into model-test-refactor 2025-11-12 10:18:07 +05:30
DN6 1f026ad14e update 2025-11-12 10:17:54 +05:30
YiYi Xu 0c7589293b fix copies (#12637)
* fix

* remoce cocpies instead
2025-11-11 15:44:55 -10:00
Charchit Sharma ff263947ad Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers (#12594)
* Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers

- Store t_dim, h_dim, w_dim as instance variables in WanRotaryPosEmbed and SkyReelsV2RotaryPosEmbed __init__
- Use stored dimensions in forward() instead of recalculating with different formula
- Fixes inconsistency between init (using // 6) and forward (using // 3)
- Ensures split_sizes matches the dimensions used to create rotary embeddings

* quality fix

---------

Co-authored-by: Charchit Sharma <charchitsharma@A-267.local>
2025-11-11 11:45:36 -10:00
Dhruv Nair 66e6a0215f [CI] Remove unittest dependency from testing_utils.py (#12621)
* update

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/testing_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-11 16:40:39 +05:30
Cesaryuan 5a47442f92 Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples (#12544)
* Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-10 13:57:52 -08:00
Dhruv Nair 8f6328c4a4 [Modular] Clean up docs (#12604)
update

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-11-10 23:37:29 +05:30
Dhruv Nair 8d45f219d0 Fix Context Parallel validation checks (#12446)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-10 23:37:07 +05:30
52 changed files with 3959 additions and 292 deletions
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# LoopSequentialPipelineBlocks
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
@@ -21,7 +21,6 @@ This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBl
[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
- `__call__` method defines the loop structure and iteration logic.
@@ -90,4 +89,4 @@ Add more loop blocks to run within each iteration with [`~modular_pipelines.Loop
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
```
```
@@ -37,17 +37,7 @@ A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermedi
]
```
- `intermediate_inputs` are values typically created from a previous block but it can also be directly provided if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
Use `InputParam` to define `intermediate_inputs`.
```py
user_intermediate_inputs = [
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
]
```
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or available as the final output from running the pipeline.
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
Use `OutputParam` to define `intermediate_outputs`.
@@ -65,8 +55,8 @@ The intermediate inputs and outputs share data to connect blocks. They are acces
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
2. Implement the computation logic on the `inputs`.
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
4. Return the components and state which becomes available to the next block.
@@ -76,7 +66,7 @@ def __call__(self, components, state):
block_state = self.get_block_state(state)
# Your computation logic here
# block_state contains all your inputs and intermediate_inputs
# block_state contains all your inputs
# Access them like: block_state.image, block_state.processed_image
# Update the pipeline state with your updated block_states
@@ -112,4 +102,4 @@ def __call__(self, components, state):
unet = components.unet
vae = components.vae
scheduler = components.scheduler
```
```
@@ -183,7 +183,7 @@ from diffusers.modular_pipelines import ComponentsManager
components = ComponentManager()
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
dd_pipeline.load_componenets(torch_dtype=torch.float16)
dd_pipeline.to("cuda")
```
@@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License.
# SequentialPipelineBlocks
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `intermediate_inputs`.
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
<hfoptions id="sequential">
<hfoption id="InputBlock">
@@ -110,4 +110,4 @@ Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by cal
```py
print(blocks)
print(blocks.doc)
```
```
+1 -1
View File
@@ -45,7 +45,7 @@ def check_size(image, height, width):
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
inner_image = inner_image.convert("RGBA")
image = image.convert("RGB")
+11 -6
View File
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
down_block_types: Tuple[str, ...],
up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
@@ -438,16 +438,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
+47 -25
View File
@@ -44,11 +44,16 @@ class ContextParallelConfig:
Args:
ring_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
ulysses_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
good interconnect bandwidth.
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
Whether to convert output and LSE to float32 for ring attention numerical stability.
rotate_method (`str`, *optional*, defaults to `"allgather"`):
@@ -79,29 +84,46 @@ class ContextParallelConfig:
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.ring_degree == 1 and self.ulysses_degree == 1:
raise ValueError(
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
@property
def mesh_shape(self) -> Tuple[int, int]:
return (self.ring_degree, self.ulysses_degree)
@property
def mesh_dim_names(self) -> Tuple[str, str]:
"""Dimension names for the device mesh."""
return ("ring", "ulysses")
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
self._rank = rank
self._world_size = world_size
self._device = device
self._mesh = mesh
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
if self.ulysses_degree * self.ring_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
if self._flattened_mesh is None:
self._flattened_mesh = self._mesh._flatten()
if self._ring_mesh is None:
self._ring_mesh = self._mesh["ring"]
if self._ulysses_mesh is None:
self._ulysses_mesh = self._mesh["ulysses"]
if self._ring_local_rank is None:
self._ring_local_rank = self._ring_mesh.get_local_rank()
if self._ulysses_local_rank is None:
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
@dataclass
@@ -119,7 +141,7 @@ class ParallelConfig:
_rank: int = None
_world_size: int = None
_device: torch.device = None
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
_mesh: torch.distributed.device_mesh.DeviceMesh = None
def setup(
self,
@@ -127,14 +149,14 @@ class ParallelConfig:
world_size: int,
device: torch.device,
*,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
):
self._rank = rank
self._world_size = world_size
self._device = device
self._cp_mesh = cp_mesh
self._mesh = mesh
if self.context_parallel_config is not None:
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
self.context_parallel_config.setup(rank, world_size, device, mesh)
@dataclass(frozen=True)
+9 -18
View File
@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:
_backends = {}
_constraints = {}
_supported_arg_names = {}
_supports_context_parallel = {}
_supports_context_parallel = set()
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
_checks_enabled = DIFFUSERS_ATTN_CHECKS
@@ -237,7 +237,9 @@ class _AttentionBackendRegistry:
cls._backends[backend] = func
cls._constraints[backend] = constraints or []
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
cls._supports_context_parallel[backend] = supports_context_parallel
if supports_context_parallel:
cls._supports_context_parallel.add(backend.value)
return func
return decorator
@@ -251,15 +253,12 @@ class _AttentionBackendRegistry:
return list(cls._backends.keys())
@classmethod
def _is_context_parallel_enabled(
cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
def _is_context_parallel_available(
cls,
backend: AttentionBackendName,
) -> bool:
supports_context_parallel = backend in cls._supports_context_parallel
is_degree_greater_than_1 = parallel_config is not None and (
parallel_config.context_parallel_config.ring_degree > 1
or parallel_config.context_parallel_config.ulysses_degree > 1
)
return supports_context_parallel and is_degree_greater_than_1
supports_context_parallel = backend.value in cls._supports_context_parallel
return supports_context_parallel
@contextlib.contextmanager
@@ -306,14 +305,6 @@ def dispatch_attention_fn(
backend_name = AttentionBackendName(backend)
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
backend_name, parallel_config
):
raise ValueError(
f"Backend {backend_name} either does not support context parallelism or context parallelism "
f"was enabled with a world size of 1."
)
kwargs = {
"query": query,
"key": key,
@@ -102,7 +102,7 @@ def get_block(
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_mutliscales: Tuple[int] = (),
qkv_mutliscales: Tuple[int, ...] = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
@@ -206,8 +206,8 @@ class Encoder(nn.Module):
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True,
@@ -292,8 +292,8 @@ class Decoder(nn.Module):
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
norm_type: Union[str, Tuple[str]] = "rms_norm",
act_fn: Union[str, Tuple[str]] = "silu",
@@ -440,8 +440,8 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
upsample_block_type: str = "pixel_shuffle",
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels: Tuple[int] = (128, 256, 256, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
latent_channels: int = 16,
layers_per_block: int = 3,
act_fn: str = "silu",
@@ -653,7 +653,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels: Tuple[int] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
act_fn: str = "silu",
norm_num_groups: int = 32,
@@ -601,7 +601,7 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 32,
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
@@ -688,8 +688,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
in_channels: int = 15,
out_channels: int = 3,
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
latent_channels: int = 12,
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
act_fn: str = "silu",
@@ -679,7 +679,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self,
base_dim: int = 96,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
self,
in_channels: int = 4,
out_channels: int = 3,
block_out_channels: Tuple[int] = (128, 256, 512, 512),
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
):
super().__init__()
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1,
latent_channels: int = 4,
sample_size: int = 32,
@@ -971,7 +971,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
@@ -293,14 +293,14 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
self,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
time_embedding_mix: float = 1.0,
learn_time_embedding: bool = False,
num_attention_heads: Union[int, Tuple[int]] = 4,
block_out_channels: Tuple[int] = (4, 8, 16, 16),
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
base_block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
cross_attention_dim: int = 1024,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
@@ -436,7 +436,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
time_embedding_mix: int = 1.0,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
):
r"""
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
@@ -529,14 +529,19 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
self,
# unet configs
sample_size: Optional[int] = 96,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
norm_num_groups: Optional[int] = 32,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -550,10 +555,10 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
# additional controlnet configs
time_embedding_mix: float = 1.0,
ctrl_conditioning_channels: int = 3,
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
ctrl_conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
ctrl_conditioning_channel_order: str = "rgb",
ctrl_learn_time_embedding: bool = False,
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
ctrl_block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
ctrl_max_norm_num_groups: int = 32,
):
+53 -33
View File
@@ -1484,59 +1484,71 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
config: Union[ParallelConfig, ContextParallelConfig],
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
):
from ..hooks.context_parallel import apply_context_parallel
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
logger.warning(
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
)
if not torch.distributed.is_available() and not torch.distributed.is_initialized():
raise RuntimeError(
"torch.distributed must be available and initialized before calling `enable_parallelism`."
)
from ..hooks.context_parallel import apply_context_parallel
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
from .attention_processor import Attention, MochiAttention
if isinstance(config, ContextParallelConfig):
config = ParallelConfig(context_parallel_config=config)
if not torch.distributed.is_initialized():
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
device = torch.device(device_type, rank % device_module.device_count())
cp_mesh = None
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
if config.context_parallel_config is not None:
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
attention_backend = processor._attention_backend
if attention_backend is None:
attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
else:
attention_backend = AttentionBackendName(attention_backend)
if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
raise ValueError(
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
f"calling `enable_parallelism()`."
)
# All modules use the same attention processor and backend. We don't need to
# iterate over all modules after checking the first processor
break
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
)
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
mesh_dim_names=("ring", "ulysses"),
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,
)
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
if cp_plan is None and self._cp_plan is None:
raise ValueError(
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
)
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
if config.context_parallel_config is not None:
apply_context_parallel(self, config.context_parallel_config, cp_plan)
config.setup(rank, world_size, device, mesh=mesh)
self._parallel_config = config
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
@@ -1545,6 +1557,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
continue
processor._parallel_config = config
if config.context_parallel_config is not None:
if cp_plan is None and self._cp_plan is None:
raise ValueError(
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
)
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
apply_context_parallel(self, config.context_parallel_config, cp_plan)
@classmethod
def _load_pretrained_model(
cls,
@@ -914,7 +914,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
image_condition_type: Optional[str] = None,
) -> None:
super().__init__()
@@ -139,7 +139,7 @@ class HunyuanVideoFramepackTransformer3DModel(
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
image_condition_type: Optional[str] = None,
has_image_proj: int = False,
image_proj_dim: int = 1152,
@@ -689,7 +689,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
text_embed_dim: int = 3584,
text_embed_2_dim: Optional[int] = None,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (64, 64),
rope_axes_dim: Tuple[int, ...] = (64, 64),
use_meanflow: bool = False,
) -> None:
super().__init__()
@@ -172,7 +172,6 @@ class SanaLinearAttnProcessor3_0:
return hidden_states
# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
class WanRotaryPosEmbed(nn.Module):
def __init__(
self,
@@ -389,6 +389,10 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
t_dim = attention_head_dim - h_dim - w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
self.t_dim = t_dim
self.h_dim = h_dim
self.w_dim = w_dim
freqs_cos = []
freqs_sin = []
@@ -412,11 +416,7 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
@@ -570,7 +570,7 @@ class SkyReelsV2Transformer3DModel(
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 16,
attention_head_dim: int = 128,
in_channels: int = 16,
@@ -362,6 +362,11 @@ class WanRotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
self.t_dim = t_dim
self.h_dim = h_dim
self.w_dim = w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
@@ -387,11 +392,7 @@ class WanRotaryPosEmbed(nn.Module):
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
@@ -563,7 +564,7 @@ class WanTransformer3DModel(
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
@@ -182,7 +182,7 @@ class WanVACETransformer3DModel(
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
+4 -4
View File
@@ -86,11 +86,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D",
down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: str = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
block_out_channels: Tuple[int, ...] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
@@ -177,16 +177,21 @@ class UNet2DConditionModel(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -486,10 +491,10 @@ class UNet2DConditionModel(
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
down_block_types: Tuple[str, ...],
up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
groups: int = 32,
attention_head_dim: int = 64,
layers_per_block: Union[int, Tuple[int]] = 3,
block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
cross_attention_dim: Union[int, Tuple[int]] = 4096,
encoder_hid_dim: int = 4096,
):
@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
sample_size: Optional[int] = None,
in_channels: int = 8,
out_channels: int = 4,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768,
layers_per_block: Union[int, Tuple[int]] = 2,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
num_frames: int = 25,
):
super().__init__()
@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
timestep_ratio_embedding_dim: int = 64,
patch_size: int = 1,
conditioning_dim: int = 2048,
block_out_channels: Tuple[int] = (2048, 2048),
num_attention_heads: Tuple[int] = (32, 32),
down_num_layers_per_block: Tuple[int] = (8, 24),
up_num_layers_per_block: Tuple[int] = (24, 8),
block_out_channels: Tuple[int, ...] = (2048, 2048),
num_attention_heads: Tuple[int, ...] = (32, 32),
down_num_layers_per_block: Tuple[int, ...] = (8, 24),
up_num_layers_per_block: Tuple[int, ...] = (24, 8),
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
1,
1,
@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
kernel_size=3,
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
self_attn: Union[bool, Tuple[bool]] = True,
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
switch_level: Optional[Tuple[bool]] = None,
):
"""
+7 -7
View File
@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
block_out_channels: Tuple[int] = (64,)
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
block_out_channels: int = (64,)
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int] = (64,)
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 1
act_fn: str = "silu"
latent_channels: int = 4
@@ -245,16 +245,21 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
out_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -374,21 +374,21 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"DownBlockFlat",
),
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"UpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
+4 -4
View File
@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
def __init__(
self,
*,
param_names: Tuple[str] = (
param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin):
def __init__(
self,
*,
param_names: Tuple[str] = (
param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
"nerstf.mlp.3.weight",
),
param_shapes: Tuple[Tuple[int]] = (
param_shapes: Tuple[Tuple[int, int], ...] = (
(256, 93),
(256, 256),
(256, 256),
@@ -804,7 +804,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
n_hidden_layers: int = 6,
act_fn: str = "swish",
insert_direction_at: int = 4,
background: Tuple[float] = (
background: Tuple[float, ...] = (
255.0,
255.0,
255.0,
+14
View File
@@ -32,6 +32,20 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
config.addinivalue_line("markers", "training: marks tests for training functionality")
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
def pytest_addoption(parser):
+12 -12
View File
@@ -317,9 +317,9 @@ class ModelUtilsTest(unittest.TestCase):
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
)
assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
"Model parameters don't match!"
)
assert all(
torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())
), "Model parameters don't match!"
# Remove a shard file
cached_shard_file = try_to_load_from_cache(
@@ -335,9 +335,9 @@ class ModelUtilsTest(unittest.TestCase):
# Verify error mentions the missing shard
error_msg = str(context.exception)
assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
f"Expected error about missing shard, got: {error_msg}"
)
assert (
cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg
), f"Expected error about missing shard, got: {error_msg}"
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
@@ -354,9 +354,9 @@ class ModelUtilsTest(unittest.TestCase):
)
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 3, (
"3 HEAD requests one for config, one for model, and one for shard index file."
)
assert (
download_requests.count("HEAD") == 3
), "3 HEAD requests one for config, one for model, and one for shard index file."
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
@@ -368,9 +368,9 @@ class ModelUtilsTest(unittest.TestCase):
)
cache_requests = [r.method for r in m.request_history]
assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
"We should call only `model_info` to check for commit hash and knowing if shard index is present."
)
assert (
"HEAD" == cache_requests[0] and len(cache_requests) == 2
), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
+37
View File
@@ -0,0 +1,37 @@
from .attention import AttentionTesterMixin
from .common import ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .quantization import (
BitsAndBytesTesterMixin,
GGUFTesterMixin,
ModelOptTesterMixin,
QuantizationTesterMixin,
QuantoTesterMixin,
TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin
__all__ = [
"AttentionTesterMixin",
"BitsAndBytesTesterMixin",
"CPUOffloadTesterMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
"LayerwiseCastingTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptTesterMixin",
"ModelTesterMixin",
"QuantizationTesterMixin",
"QuantoTesterMixin",
"SingleFileTesterMixin",
"TorchAoTesterMixin",
"TorchCompileTesterMixin",
"TrainingTesterMixin",
]
+180
View File
@@ -0,0 +1,180 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
)
from ...testing_utils import is_attention, require_accelerator, torch_device
@is_attention
@require_accelerator
class AttentionTesterMixin:
"""
Mixin class for testing attention processor and module functionality on models.
Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- base_precision: Tolerance for floating point comparisons (default: 1e-3)
- uses_custom_attn_processor: Whether model uses custom attention processors (default: False)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests
"""
base_precision = 1e-3
def test_fuse_unfuse_qkv_projections(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
if not hasattr(model, "fuse_qkv_projections"):
pytest.skip("Model does not support QKV projection fusion.")
# Get output before fusion
with torch.no_grad():
output_before_fusion = model(**inputs_dict)
if isinstance(output_before_fusion, dict):
output_before_fusion = output_before_fusion.to_tuple()[0]
# Fuse projections
model.fuse_qkv_projections()
# Verify fusion occurred by checking for fused attributes
has_fused_projections = False
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
has_fused_projections = True
assert module.fused_projections, "fused_projections flag should be True"
break
if has_fused_projections:
# Get output after fusion
with torch.no_grad():
output_after_fusion = model(**inputs_dict)
if isinstance(output_after_fusion, dict):
output_after_fusion = output_after_fusion.to_tuple()[0]
# Verify outputs match
assert torch.allclose(
output_before_fusion, output_after_fusion, atol=self.base_precision
), "Output should not change after fusing projections"
# Unfuse projections
model.unfuse_qkv_projections()
# Verify unfusion occurred
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
assert not module.fused_projections, "fused_projections flag should be False"
# Get output after unfusion
with torch.no_grad():
output_after_unfusion = model(**inputs_dict)
if isinstance(output_after_unfusion, dict):
output_after_unfusion = output_after_unfusion.to_tuple()[0]
# Verify outputs still match
assert torch.allclose(
output_before_fusion, output_after_unfusion, atol=self.base_precision
), "Output should match original after unfusing projections"
def test_get_set_processor(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
# Check if model has attention processors
if not hasattr(model, "attn_processors"):
pytest.skip("Model does not have attention processors.")
# Test getting processors
processors = model.attn_processors
assert isinstance(processors, dict), "attn_processors should return a dict"
assert len(processors) > 0, "Model should have at least one attention processor"
# Test that all processors can be retrieved via get_processor
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
processor = module.get_processor()
assert processor is not None, "get_processor should return a processor"
# Test setting a new processor
new_processor = AttnProcessor()
module.set_processor(new_processor)
retrieved_processor = module.get_processor()
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
def test_attention_processor_dict(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict of new processors
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
# Set processors using dict
model.set_attn_processor(new_processors)
# Verify all processors were set
updated_processors = model.attn_processors
for key in current_processors.keys():
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
def test_attention_processor_count_mismatch_raises_error(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict with wrong number of processors
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
# Verify error is raised
with pytest.raises(ValueError) as exc_info:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
+514
View File
@@ -0,0 +1,514 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
from typing import Dict, List, Tuple
import pytest
import torch
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
from ...testing_utils import torch_device
def compute_module_persistent_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the size of each submodule of a given model (parameters + persistent buffers).
"""
if dtype is not None:
dtype = _get_proper_dtype(dtype)
dtype_size = dtype_byte_size(dtype)
if special_dtypes is not None:
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
module_list = []
module_list = named_persistent_module_tensors(model, recurse=True)
for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
size = tensor.numel() * dtype_byte_size(tensor.dtype)
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
# According to the code in set_module_tensor_to_device, these types won't be converted
# so use their original size here
size = tensor.numel() * dtype_byte_size(tensor.dtype)
else:
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
name_parts = name.split(".")
for idx in range(len(name_parts) + 1):
module_sizes[".".join(name_parts[:idx])] += size
return module_sizes
def calculate_expected_num_shards(index_map_path):
"""
Calculate expected number of shards from index file.
Args:
index_map_path: Path to the sharded checkpoint index file
Returns:
int: Expected number of shards
"""
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
return expected_num_shards
def check_device_map_is_respected(model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
while len(param_name) > 0 and param_name not in device_map:
param_name = ".".join(param_name.split(".")[:-1])
if param_name not in device_map:
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
else:
assert param.device == torch.device(
param_device
), f"Expected device {param_device} for {param_name}, got {param.device}"
class ModelTesterMixin:
"""
Base mixin class for model testing with common test methods.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
model_class = None
base_precision = 1e-3
model_split_percents = [0.5, 0.7]
def get_init_dict(self):
raise NotImplementedError("get_init_dict must be implemented by subclasses. ")
def get_dummy_inputs(self):
raise NotImplementedError(
"get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict."
)
def test_from_save_pretrained(self, expected_max_diff=5e-5):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
# check if all parameters shape are the same
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
assert (
param_1.shape == param_2.shape
), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
with torch.no_grad():
image = model(**self.get_dummy_inputs())
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**self.get_dummy_inputs())
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
assert (
max_diff <= expected_max_diff
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, variant="fp16")
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
# non-variant cannot be loaded
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmpdirname)
# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
new_model.to(torch_device)
with torch.no_grad():
image = model(**self.get_dummy_inputs())
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**self.get_dummy_inputs())
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
assert (
max_diff <= expected_max_diff
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
def test_from_save_pretrained_dtype(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
with tempfile.TemporaryDirectory() as tmpdirname:
model.to(dtype)
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if (
hasattr(self.model_class, "_keep_in_fp32_modules")
and self.model_class._keep_in_fp32_modules is None
):
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(
tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
)
assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with torch.no_grad():
first = model(**self.get_dummy_inputs())
if isinstance(first, dict):
first = first.to_tuple()[0]
second = model(**self.get_dummy_inputs())
if isinstance(second, dict):
second = second.to_tuple()[0]
# Remove NaN values and compute max difference
first_flat = first.flatten()
second_flat = second.flatten()
# Filter out NaN values
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
first_filtered = first_flat[mask]
second_filtered = second_flat[mask]
max_diff = torch.abs(first_filtered - second_filtered).max().item()
assert (
max_diff <= expected_max_diff
), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output is not None, "Model output is None"
assert (
output.shape == expected_output_shape
), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
def test_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
# Track progress in https://github.com/pytorch/pytorch/issues/77764
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0
return t.to(device)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
assert torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
), (
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
def test_model_config_to_json_string(self):
model = self.model_class(**self.get_init_dict())
json_string = model.config.to_json_string()
assert isinstance(json_string, str), "Config to_json_string should return a string"
assert len(json_string) > 0, "JSON string should not be empty"
@require_accelerator
@pytest.mark.skipif(torch_device not in ["cuda", "xpu"])
def test_from_save_pretrained_float16_bfloat16(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules
with tempfile.TemporaryDirectory() as tmp_dir:
for torch_dtype in [torch.bfloat16, torch.float16]:
model.to(torch_dtype).save_pretrained(tmp_dir)
model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device)
for name, param in model_loaded.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == torch_dtype
with torch.no_grad():
output = model(**get_dummy_inputs())
output_loaded = model_loaded(**get_dummy_inputs())
assert torch.allclose(
output, output_loaded, atol=1e-4
), f"Loaded model output differs for {torch_dtype}"
@require_accelerator
def test_sharded_checkpoints(self):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert (
actual_num_shards == expected_num_shards
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
new_model = self.model_class.from_pretrained(tmp_dir).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match after sharded save/load"
@require_accelerator
def test_sharded_checkpoints_with_variant(self):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(
os.path.join(tmp_dir, index_filename)
), f"Variant index file {index_filename} should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert (
actual_num_shards == expected_num_shards
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match after variant sharded save/load"
@require_accelerator
def test_sharded_checkpoints_with_parallel_loading(self):
import time
from diffusers.utils import constants
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
# Save original values to restore after test
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
try:
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert (
actual_num_shards == expected_num_shards
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
start_time = time.time()
model_sequential = self.model_class.from_pretrained(tmp_dir).eval()
sequential_load_time = time.time() - start_time
model_sequential = model_sequential.to(torch_device)
torch.manual_seed(0)
# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
start_time = time.time()
model_parallel = self.model_class.from_pretrained(tmp_dir).eval()
parallel_load_time = time.time() - start_time
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel)
assert torch.allclose(
base_output[0], output_parallel[0], atol=1e-5
), "Output should match with parallel loading"
# Verify parallel loading is faster or at least not significantly slower
# For small test models, the difference might be negligible or even slightly slower due to overhead
# so we just check that parallel loading completed successfully and outputs match
assert (
parallel_load_time < sequential_load_time
), f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
finally:
# Restore original values
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
if original_parallel_workers is not None:
constants.HF_PARALLEL_WORKERS = original_parallel_workers
@require_torch_multi_accelerator
def test_model_parallelism(self):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match with model parallelism"
+162
View File
@@ -0,0 +1,162 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import tempfile
import pytest
import torch
from ...testing_utils import (
backend_empty_cache,
is_torch_compile,
require_accelerator,
require_torch_version_greater,
torch_device,
)
@is_torch_compile
@require_accelerator
@require_torch_version_greater("2.7.1")
class TorchCompileTesterMixin:
"""
Mixin class for testing torch.compile functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: compile
Use `pytest -m "not compile"` to skip these tests
"""
different_shapes_for_compilation = None
def setup_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def test_torch_compile_recompilation_and_graph_break(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_torch_compile_repeated_blocks(self):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
torch.no_grad(),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_compile_with_group_offloading(self):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
torch._dynamo.config.cache_size_limit = 10000
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.eval()
group_offload_kwargs = {
"onload_device": torch_device,
"offload_device": "cpu",
"offload_type": "block_level",
"num_blocks_per_group": 1,
"use_stream": True,
"non_blocking": True,
}
model.enable_group_offload(**group_offload_kwargs)
model.compile()
with torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_compile_on_different_shapes(self):
if self.different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
torch.fx.experimental._config.use_duck_shape = False
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True, dynamic=True)
for height, width in self.different_shapes_for_compilation:
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
inputs_dict = self.get_dummy_inputs(height=height, width=width)
_ = model(**inputs_dict)
def test_compile_works_with_aot(self):
from torch._inductor.package import load_package
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
assert os.path.exists(package_path), f"Package file not created at {package_path}"
loaded_binary = load_package(package_path, run_single_threaded=True)
model.forward = loaded_binary
with torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
+109
View File
@@ -0,0 +1,109 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import uuid
import pytest
import torch
from huggingface_hub.utils import is_jinja_available
from ...others.test_utils import TOKEN, USER, is_staging_test
@is_staging_test
class ModelPushToHubTesterMixin:
"""
Mixin class for testing push_to_hub functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
"""
identifier = uuid.uuid4()
repo_id = f"test-model-{identifier}"
org_repo_id = f"valid_org/{repo_id}-org"
def test_push_to_hub(self):
"""Test pushing model to hub and loading it back."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.push_to_hub(self.repo_id, token=TOKEN)
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained"
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(
p1, p2
), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
def test_push_to_hub_in_organization(self):
"""Test pushing model to hub in organization namespace."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.push_to_hub(self.org_repo_id, token=TOKEN)
new_model = self.model_class.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained"
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
new_model = self.model_class.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(
p1, p2
), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)
def test_push_to_hub_library_name(self):
"""Test that library_name in model card is set to 'diffusers'."""
if not is_jinja_available():
pytest.skip("Model card tests cannot be performed without Jinja installed.")
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.push_to_hub(self.repo_id, token=TOKEN)
model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
assert (
model_card.library_name == "diffusers"
), f"Expected library_name 'diffusers', got {model_card.library_name}"
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
+205
View File
@@ -0,0 +1,205 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import torch
from diffusers.models.attention_processor import IPAdapterAttnProcessor
from ...testing_utils import is_ip_adapter, torch_device
def create_ip_adapter_state_dict(model):
"""
Create a dummy IP Adapter state dict for testing.
Args:
model: The model to create IP adapter weights for
Returns:
dict: IP adapter state dict with to_k_ip and to_v_ip weights
"""
ip_state_dict = {}
key_id = 1
for name in model.attn_processors.keys():
# Skip self-attention processors
cross_attention_dim = getattr(model.config, "cross_attention_dim", None)
if cross_attention_dim is None:
continue
# Get hidden size based on model architecture
hidden_size = getattr(model.config, "hidden_size", cross_attention_dim)
# Create IP adapter processor to get state dict structure
sd = IPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
).state_dict()
ip_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
}
)
key_id += 2
return {"ip_adapter": ip_state_dict}
def check_if_ip_adapter_correctly_set(model) -> bool:
"""
Check if IP Adapter processors are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if IP Adapter is correctly set, False otherwise
"""
for module in model.attn_processors.values():
if isinstance(module, IPAdapterAttnProcessor):
return True
return False
@is_ip_adapter
class IPAdapterTesterMixin:
"""
Mixin class for testing IP Adapter functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: ip_adapter
Use `pytest -m "not ip_adapter"` to skip these tests
"""
def create_ip_adapter_state_dict(self, model):
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
def test_load_ip_adapter(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
# Create dummy IP adapter state dict
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
# Load IP adapter
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model), "IP Adapter processors not set correctly"
torch.manual_seed(0)
# Create dummy image embeds for IP adapter
cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict_with_adapter = inputs_dict.copy()
inputs_dict_with_adapter["image_embeds"] = image_embeds
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
assert not torch.allclose(
output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4
), "Output should differ with IP Adapter enabled"
def test_ip_adapter_scale(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Create and load dummy IP adapter state dict
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
# Test scale = 0.0 (no effect)
model.set_ip_adapter_scale(0.0)
torch.manual_seed(0)
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Test scale = 1.0 (full effect)
model.set_ip_adapter_scale(1.0)
torch.manual_seed(0)
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should differ with different scales
assert not torch.allclose(
output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4
), "Output should differ with different IP Adapter scales"
def test_unload_ip_adapter(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Save original processors
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
# Create and load IP adapter
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set"
# Unload IP adapter
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
# Verify processors are restored
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
assert original_processors == current_processors, "Processors should be restored after unload"
def test_ip_adapter_save_load(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Create and load IP adapter
ip_adapter_state_dict = self.create_ip_adapter_state_dict()
model._load_ip_adapter_weights([ip_adapter_state_dict])
torch.manual_seed(0)
output_before_save = model(**inputs_dict, return_dict=False)[0]
with tempfile.TemporaryDirectory() as tmpdir:
# Save the IP adapter weights
save_path = os.path.join(tmpdir, "ip_adapter.safetensors")
import safetensors.torch
safetensors.torch.save_file(ip_adapter_state_dict["ip_adapter"], save_path)
# Unload and reload
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
# Reload from saved file
loaded_state_dict = {"ip_adapter": safetensors.torch.load_file(save_path)}
model._load_ip_adapter_weights([loaded_state_dict])
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be loaded"
torch.manual_seed(0)
output_after_load = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should match before and after save/load
assert torch.allclose(
output_before_save, output_after_load, atol=1e-4, rtol=1e-4
), "Output should match before and after save/load"
+220
View File
@@ -0,0 +1,220 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import pytest
import safetensors.torch
import torch
from diffusers.utils.testing_utils import check_if_dicts_are_equal
from ...testing_utils import is_lora, require_peft_backend, torch_device
def check_if_lora_correctly_set(model) -> bool:
"""
Check if LoRA layers are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if LoRA is correctly set, False otherwise
"""
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
@is_lora
@require_peft_backend
class LoraTesterMixin:
"""
Mixin class for testing LoRA/PEFT functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: lora
Use `pytest -m "not lora"` to skip these tests
"""
def setup_method(self):
from diffusers.loaders.peft import PeftAdapterMixin
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(
output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4
), "Output should differ with LoRA enabled"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
assert os.path.isfile(
os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
), "LoRA weights file not created"
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}"
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(
output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
), "Output should differ with LoRA enabled"
assert torch.allclose(
outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
), "Outputs should match before and after save/load"
def test_lora_wrong_adapter_name_raises_error(self):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
def test_lora_adapter_wrong_metadata_raises_error(self):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
+443
View File
@@ -0,0 +1,443 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import glob
import inspect
import tempfile
from functools import wraps
import pytest
import torch
from accelerate.utils.modeling import compute_module_sizes
from diffusers.utils.testing_utils import _check_safetensors_serialization
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
is_cpu_offload,
is_group_offload,
is_memory,
require_accelerator,
torch_device,
)
from .common import check_device_map_is_respected
def cast_maybe_tensor_dtype(inputs_dict, from_dtype, to_dtype):
"""Helper to cast tensor inputs from one dtype to another."""
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor) and value.dtype == from_dtype:
inputs_dict[key] = value.to(to_dtype)
return inputs_dict
def require_offload_support(func):
"""
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
return func(self, *args, **kwargs)
return wrapper
def require_group_offload_support(func):
"""
Decorator to skip tests if model doesn't support group offloading.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
return func(self, *args, **kwargs)
return wrapper
@is_cpu_offload
class CPUOffloadTesterMixin:
"""
Mixin class for testing CPU offloading functionality.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- model_split_percents: List of percentages for splitting model across devices
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cpu_offload
Use `pytest -m "not cpu_offload"` to skip these tests
"""
model_split_percents = [0.5, 0.7]
@require_offload_support
def test_cpu_offload(self):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match with CPU offloading"
@require_offload_support
def test_disk_offload_without_safetensors(self):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
# Force disk offload by setting very small CPU memory
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
# This errors out because it's missing an offload folder
with pytest.raises(ValueError):
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading"
@require_offload_support
def test_disk_offload_with_safetensors(self):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match with disk offloading (safetensors)"
@is_group_offload
class GroupOffloadTesterMixin:
"""
Mixin class for testing group offloading functionality.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: group_offload
Use `pytest -m "not group_offload"` to skip these tests
"""
@require_group_offload_support
def test_group_offloading(self, record_stream=False):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
@torch.no_grad()
def run_forward(model):
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
), "Group offloading hook should be set"
model.eval()
return model(**inputs_dict)[0]
model = self.model_class(**init_dict)
model.to(torch_device)
output_without_group_offloading = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading1 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
output_with_group_offloading2 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level")
output_with_group_offloading3 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading1, atol=1e-5
), "Output should match with block-level offloading"
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading2, atol=1e-5
), "Output should match with non-blocking block-level offloading"
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading3, atol=1e-5
), "Output should match with leaf-level offloading"
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading4, atol=1e-5
), "Output should match with leaf-level offloading with stream"
@require_group_offload_support
@torch.no_grad()
def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"):
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
_ = model(**inputs_dict)[0]
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
storage_dtype, compute_dtype = torch.float16, torch.float32
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
model.enable_group_offload(
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@require_group_offload_support
@torch.no_grad()
@torch.inference_mode()
def test_group_offloading_with_disk(self, offload_type="block_level", record_stream=False, atol=1e-5):
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
return "generator" in params
def _run_forward(model, inputs_dict):
accepts_generator = _has_generator_arg(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
return model(**inputs_dict)[0]
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
model.to(torch_device)
output_without_group_offloading = _run_forward(model, inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
num_blocks_per_group = None if offload_type == "leaf_level" else 1
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
with tempfile.TemporaryDirectory() as tmpdir:
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading, atol=atol
), "Output should match with disk-based group offloading"
class LayerwiseCastingTesterMixin:
"""
Mixin class for testing layerwise dtype casting for memory optimization.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
@torch.no_grad()
def test_layerwise_casting_memory(self):
MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0
def reset_memory_stats():
gc.collect()
backend_synchronize(torch_device)
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
def get_memory_usage(storage_dtype, compute_dtype):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**config).eval()
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
reset_memory_stats()
model(**inputs_dict)
model_memory_footprint = model.get_memory_footprint()
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
return model_memory_footprint, peak_inference_memory_allocated_mb
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
torch.float8_e4m3fn, torch.bfloat16
)
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
assert (
fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint
), "Memory footprint should decrease with lower precision storage"
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
assert (
fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory
), "Peak memory should be lower with bf16 compute on newer GPUs"
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
# bytes. This only happens for some models, so we allow a small tolerance.
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
assert (
fp8_e4m3_fp32_max_memory < fp32_max_memory
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
), "Peak memory should be lower or within tolerance with fp8 storage"
@is_memory
@require_accelerator
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
"""
Combined mixin class for all memory optimization tests including CPU/disk offloading,
group offloading, and layerwise dtype casting.
This mixin inherits from:
- CPUOffloadTesterMixin: CPU and disk offloading tests
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: memory
Use `pytest -m "not memory"` to skip these tests
"""
pass
+833
View File
@@ -0,0 +1,833 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import pytest
import torch
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_nvidia_modelopt_available,
is_optimum_quanto_available,
)
from ...testing_utils import (
backend_empty_cache,
is_bitsandbytes,
is_gguf,
is_modelopt,
is_quanto,
is_torchao,
nightly,
require_accelerate,
require_accelerator,
require_bitsandbytes_version_greater,
require_gguf_version_greater_or_equal,
require_quanto,
require_torchao_version_greater_or_equal,
torch_device,
)
if is_nvidia_modelopt_available():
import modelopt.torch.quantization as mtq
if is_bitsandbytes_available():
import bitsandbytes as bnb
if is_optimum_quanto_available():
from optimum.quanto import QLinear
if is_gguf_available():
pass
if is_torchao_available():
if is_torchao_version(">=", "0.9.0"):
pass
@require_accelerator
class QuantizationTesterMixin:
"""
Base mixin class providing common test implementations for quantization testing.
Backend-specific mixins should:
1. Implement _create_quantized_model(config_kwargs)
2. Implement _verify_if_layer_quantized(name, module, config_kwargs)
3. Define their config dict (e.g., BNB_CONFIGS, QUANTO_WEIGHT_TYPES, etc.)
4. Use @pytest.mark.parametrize to create tests that call the common test methods below
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods in test classes:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
"""
Create a quantized model with the given config kwargs.
Args:
config_kwargs: Quantization config parameters
**extra_kwargs: Additional kwargs to pass to from_pretrained (e.g., device_map, offload_folder)
"""
raise NotImplementedError("Subclass must implement _create_quantized_model")
def _verify_if_layer_quantized(self, name, module, config_kwargs):
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
def _is_module_quantized(self, module):
"""
Check if a module is quantized. Returns True if quantized, False otherwise.
Default implementation tries _verify_if_layer_quantized and catches exceptions.
Subclasses can override for more efficient checking.
"""
try:
self._verify_if_layer_quantized("", module, {})
return True
except (AssertionError, AttributeError):
return False
def _load_unquantized_model(self):
kwargs = getattr(self, "pretrained_model_kwargs", {})
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _test_quantization_num_parameters(self, config_kwargs):
model = self._load_unquantized_model()
num_params = model.num_parameters()
model_quantized = self._create_quantized_model(config_kwargs)
num_params_quantized = model_quantized.num_parameters()
assert (
num_params == num_params_quantized
), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
model = self._load_unquantized_model()
mem = model.get_memory_footprint()
model_quantized = self._create_quantized_model(config_kwargs)
mem_quantized = model_quantized.get_memory_footprint()
ratio = mem / mem_quantized
assert (
ratio >= expected_memory_reduction
), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
def _test_quantization_inference(self, config_kwargs):
model_quantized = self._create_quantized_model(config_kwargs)
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model_quantized(**inputs)
if isinstance(output, tuple):
output = output[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
def _test_quantization_dtype_assignment(self, config_kwargs):
model = self._create_quantized_model(config_kwargs)
with pytest.raises(ValueError):
model.to(torch.float16)
with pytest.raises(ValueError):
device_0 = f"{torch_device}:0"
model.to(device=device_0, dtype=torch.float16)
with pytest.raises(ValueError):
model.float()
with pytest.raises(ValueError):
model.half()
model.to(torch_device)
def _test_quantization_lora_inference(self, config_kwargs):
try:
from peft import LoraConfig
except ImportError:
pytest.skip("peft is not available")
from diffusers.loaders.peft import PeftAdapterMixin
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__})")
model = self._create_quantized_model(config_kwargs)
lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
)
model.add_adapter(lora_config)
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model(**inputs)
if isinstance(output, tuple):
output = output[0]
assert output is not None, "Model output is None with LoRA"
assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"
def _test_quantization_serialization(self, config_kwargs):
model = self._create_quantized_model(config_kwargs)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir, safe_serialization=True)
model_loaded = self.model_class.from_pretrained(tmpdir)
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model_loaded(**inputs)
if isinstance(output, tuple):
output = output[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"
def _test_quantized_layers(self, config_kwargs):
model_fp = self._load_unquantized_model()
num_linear_layers = sum(1 for module in model_fp.modules() if isinstance(module, torch.nn.Linear))
model_quantized = self._create_quantized_model(config_kwargs)
num_fp32_modules = 0
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
for name, module in model_quantized.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
num_fp32_modules += 1
expected_quantized_layers = num_linear_layers - num_fp32_modules
num_quantized_layers = 0
for name, module in model_quantized.named_modules():
if isinstance(module, torch.nn.Linear):
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
continue
self._verify_if_layer_quantized(name, module, config_kwargs)
num_quantized_layers += 1
assert (
num_quantized_layers > 0
), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
assert (
num_quantized_layers == expected_quantized_layers
), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
"""
Test that modules specified in modules_to_not_convert are not quantized.
Args:
config_kwargs: Base quantization config kwargs
modules_to_not_convert: List of module names to exclude from quantization
"""
# Create config with modules_to_not_convert
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
# Find a module that should NOT be quantized
found_excluded = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is in the exclusion list
if any(excluded in name for excluded in modules_to_not_convert):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(
module
), f"Module {name} should not be quantized but was found to be quantized"
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
# Find a module that SHOULD be quantized (not in exclusion list)
found_quantized = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_not_convert):
if self._is_module_quantized(module):
found_quantized = True
break
assert found_quantized, "No quantized layers found outside of excluded modules"
# Compare memory footprint with fully quantized model
model_fully_quantized = self._create_quantized_model(config_kwargs)
mem_with_exclusion = model_with_exclusion.get_memory_footprint()
mem_fully_quantized = model_fully_quantized.get_memory_footprint()
assert (
mem_with_exclusion > mem_fully_quantized
), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
def _test_quantization_device_map(self, config_kwargs):
"""
Test that quantized models work correctly with device_map="auto".
Args:
config_kwargs: Base quantization config kwargs
"""
model = self._create_quantized_model(config_kwargs, device_map="auto")
# Verify device map is set
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
assert model.hf_device_map is not None, "hf_device_map should not be None"
# Verify inference works
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model(**inputs)
if isinstance(output, tuple):
output = output[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
@is_bitsandbytes
@nightly
@require_accelerator
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
class BitsAndBytesTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing BitsAndBytes quantization on models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional class attributes:
- BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test
Pytest mark: bitsandbytes
Use `pytest -m "not bitsandbytes"` to skip these tests
"""
# Standard BnB configs tested for all models
# Subclasses can override to add or modify configs
BNB_CONFIGS = {
"4bit_nf4": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.float16,
},
"4bit_fp4": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "fp4",
"bnb_4bit_compute_dtype": torch.float16,
},
"8bit": {
"load_in_8bit": True,
},
}
BNB_EXPECTED_MEMORY_REDUCTIONS = {
"4bit_nf4": 3.0,
"4bit_fp4": 3.0,
"8bit": 1.5,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = BitsAndBytesConfig(**config_kwargs)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs.update(extra_kwargs)
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
assert (
module.weight.__class__ == expected_weight_class
), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_num_parameters(self, config_name):
self._test_quantization_num_parameters(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_memory_footprint(self, config_name):
expected = self.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
self._test_quantization_memory_footprint(self.BNB_CONFIGS[config_name], expected_memory_reduction=expected)
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_inference(self, config_name):
self._test_quantization_inference(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
def test_bnb_quantization_dtype_assignment(self, config_name):
self._test_quantization_dtype_assignment(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
def test_bnb_quantization_lora_inference(self, config_name):
self._test_quantization_lora_inference(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
def test_bnb_quantization_serialization(self, config_name):
self._test_quantization_serialization(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantized_layers(self, config_name):
self._test_quantized_layers(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_config_serialization(self, config_name):
model = self._create_quantized_model(self.BNB_CONFIGS[config_name])
assert "quantization_config" in model.config, "Missing quantization_config"
_ = model.config["quantization_config"].to_dict()
_ = model.config["quantization_config"].to_diff_dict()
_ = model.config["quantization_config"].to_json_string()
def test_bnb_original_dtype(self):
config_name = list(self.BNB_CONFIGS.keys())[0]
config_kwargs = self.BNB_CONFIGS[config_name]
model = self._create_quantized_model(config_kwargs)
assert "_pre_quantization_dtype" in model.config, "Missing _pre_quantization_dtype"
assert model.config["_pre_quantization_dtype"] in [
torch.float16,
torch.float32,
torch.bfloat16,
], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}"
def test_bnb_keep_modules_in_fp32(self):
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
config_kwargs = self.BNB_CONFIGS["4bit_nf4"]
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
self.model_class._keep_in_fp32_modules = ["proj_out"]
try:
model = self._create_quantized_model(config_kwargs)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
assert (
module.weight.dtype == torch.float32
), f"Module {name} should be FP32 but is {module.weight.dtype}"
else:
assert (
module.weight.dtype == torch.uint8
), f"Module {name} should be uint8 but is {module.weight.dtype}"
with torch.no_grad():
inputs = self.get_dummy_inputs()
_ = model(**inputs)
finally:
if original_fp32_modules is not None:
self.model_class._keep_in_fp32_modules = original_fp32_modules
def test_bnb_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(self.BNB_CONFIGS["4bit_nf4"], modules_to_exclude)
def test_bnb_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"])
@is_quanto
@nightly
@require_quanto
@require_accelerate
@require_accelerator
class QuantoTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing Quanto quantization on models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional class attributes:
- QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype
Pytest mark: quanto
Use `pytest -m "not quanto"` to skip these tests
"""
QUANTO_WEIGHT_TYPES = {
"float8": {"weights_dtype": "float8"},
"int8": {"weights_dtype": "int8"},
"int4": {"weights_dtype": "int4"},
"int2": {"weights_dtype": "int2"},
}
QUANTO_EXPECTED_MEMORY_REDUCTIONS = {
"float8": 1.5,
"int8": 1.5,
"int4": 3.0,
"int2": 7.0,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = QuantoConfig(**config_kwargs)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs.update(extra_kwargs)
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}"
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
def test_quanto_quantization_num_parameters(self, weight_type_name):
self._test_quantization_num_parameters(self.QUANTO_WEIGHT_TYPES[weight_type_name])
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
def test_quanto_quantization_memory_footprint(self, weight_type_name):
expected = self.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2)
self._test_quantization_memory_footprint(
self.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected
)
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
def test_quanto_quantization_inference(self, weight_type_name):
self._test_quantization_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
@pytest.mark.parametrize("weight_type_name", ["int8"])
def test_quanto_quantized_layers(self, weight_type_name):
self._test_quantized_layers(self.QUANTO_WEIGHT_TYPES[weight_type_name])
@pytest.mark.parametrize("weight_type_name", ["int8"])
def test_quanto_quantization_lora_inference(self, weight_type_name):
self._test_quantization_lora_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
@pytest.mark.parametrize("weight_type_name", ["int8"])
def test_quanto_quantization_serialization(self, weight_type_name):
self._test_quantization_serialization(self.QUANTO_WEIGHT_TYPES[weight_type_name])
def test_quanto_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(self.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude)
def test_quanto_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"])
@is_torchao
@require_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing TorchAO quantization on models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional class attributes:
- TORCHAO_QUANT_TYPES: Dict of quantization type strings to test
Pytest mark: torchao
Use `pytest -m "not torchao"` to skip these tests
"""
TORCHAO_QUANT_TYPES = {
"int4wo": {"quant_type": "int4_weight_only"},
"int8wo": {"quant_type": "int8_weight_only"},
"int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
}
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
"int4wo": 3.0,
"int8wo": 1.5,
"int8dq": 1.5,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = TorchAoConfig(**config_kwargs)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs.update(extra_kwargs)
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
def test_torchao_quantization_num_parameters(self, quant_type):
self._test_quantization_num_parameters(self.TORCHAO_QUANT_TYPES[quant_type])
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
def test_torchao_quantization_memory_footprint(self, quant_type):
expected = self.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2)
self._test_quantization_memory_footprint(
self.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected
)
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
def test_torchao_quantization_inference(self, quant_type):
self._test_quantization_inference(self.TORCHAO_QUANT_TYPES[quant_type])
@pytest.mark.parametrize("quant_type", ["int8wo"])
def test_torchao_quantized_layers(self, quant_type):
self._test_quantized_layers(self.TORCHAO_QUANT_TYPES[quant_type])
@pytest.mark.parametrize("quant_type", ["int8wo"])
def test_torchao_quantization_lora_inference(self, quant_type):
self._test_quantization_lora_inference(self.TORCHAO_QUANT_TYPES[quant_type])
@pytest.mark.parametrize("quant_type", ["int8wo"])
def test_torchao_quantization_serialization(self, quant_type):
self._test_quantization_serialization(self.TORCHAO_QUANT_TYPES[quant_type])
def test_torchao_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
# Get a module name that exists in the model - this needs to be set by test classes
# For now, use a generic pattern that should work with transformer models
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(
self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
)
def test_torchao_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"])
@is_gguf
@nightly
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
class GGUFTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing GGUF quantization on models.
Expected class attributes:
- model_class: The model class to test
- gguf_filename: URL or path to the GGUF file
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: gguf
Use `pytest -m "not gguf"` to skip these tests
"""
gguf_filename = None
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
if config_kwargs is None:
config_kwargs = {"compute_dtype": torch.bfloat16}
config = GGUFQuantizationConfig(**config_kwargs)
kwargs = {
"quantization_config": config,
"torch_dtype": config_kwargs.get("compute_dtype", torch.bfloat16),
}
kwargs.update(extra_kwargs)
return self.model_class.from_single_file(self.gguf_filename, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs=None):
from diffusers.quantizers.gguf.utils import GGUFParameter
assert isinstance(module.weight, GGUFParameter), f"{name} weight is not GGUFParameter"
assert hasattr(module.weight, "quant_type"), f"{name} weight missing quant_type"
assert module.weight.dtype == torch.uint8, f"{name} weight dtype should be uint8"
def test_gguf_quantization_inference(self):
self._test_quantization_inference({"compute_dtype": torch.bfloat16})
def test_gguf_keep_modules_in_fp32(self):
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
_keep_in_fp32_modules = self.model_class._keep_in_fp32_modules
self.model_class._keep_in_fp32_modules = ["proj_out"]
try:
model = self._create_quantized_model()
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
assert module.weight.dtype == torch.float32, f"Module {name} should be FP32"
finally:
self.model_class._keep_in_fp32_modules = _keep_in_fp32_modules
def test_gguf_quantization_dtype_assignment(self):
self._test_quantization_dtype_assignment({"compute_dtype": torch.bfloat16})
def test_gguf_quantization_lora_inference(self):
self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16})
def test_gguf_dequantize_model(self):
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
model = self._create_quantized_model()
model.dequantize()
def _check_for_gguf_linear(model):
has_children = list(model.children())
if not has_children:
return
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
for name, module in model.named_children():
_check_for_gguf_linear(module)
def test_gguf_quantized_layers(self):
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
@is_modelopt
@nightly
@require_accelerator
@require_accelerate
@require_modelopt_version_greater_or_equal("0.33.1")
class ModelOptTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing NVIDIA ModelOpt quantization on models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional class attributes:
- MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test
Pytest mark: modelopt
Use `pytest -m "not modelopt"` to skip these tests
"""
MODELOPT_CONFIGS = {
"fp8": {"quant_type": "FP8"},
"int8": {"quant_type": "INT8"},
"int4": {"quant_type": "INT4"},
}
MODELOPT_EXPECTED_MEMORY_REDUCTIONS = {
"fp8": 1.5,
"int8": 1.5,
"int4": 3.0,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = NVIDIAModelOptConfig(**config_kwargs)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs.update(extra_kwargs)
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)"
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantization_num_parameters(self, config_name):
self._test_quantization_num_parameters(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
def test_modelopt_quantization_memory_footprint(self, config_name):
expected = self.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
self._test_quantization_memory_footprint(
self.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected
)
@pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
def test_modelopt_quantization_inference(self, config_name):
self._test_quantization_inference(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantization_dtype_assignment(self, config_name):
self._test_quantization_dtype_assignment(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantization_lora_inference(self, config_name):
self._test_quantization_lora_inference(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantization_serialization(self, config_name):
self._test_quantization_serialization(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantized_layers(self, config_name):
self._test_quantized_layers(self.MODELOPT_CONFIGS[config_name])
def test_modelopt_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(self.MODELOPT_CONFIGS["fp8"], modules_to_exclude)
def test_modelopt_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"])
+247
View File
@@ -0,0 +1,247 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from ...testing_utils import (
backend_empty_cache,
is_single_file,
nightly,
require_torch_accelerator,
torch_device,
)
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
"""Download a single file checkpoint from the Hub to a temporary directory."""
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
return path
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
"""Download diffusers config files (excluding weights) from a repository."""
path = snapshot_download(
pretrained_model_name_or_path,
ignore_patterns=[
"**/*.ckpt",
"*.ckpt",
"**/*.bin",
"*.bin",
"**/*.pt",
"*.pt",
"**/*.safetensors",
"*.safetensors",
],
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
local_dir=tmpdir,
)
return path
@nightly
@require_torch_accelerator
@is_single_file
class SingleFileTesterMixin:
"""
Mixin class for testing single file loading for models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- ckpt_path: Path or Hub path to the single file checkpoint
- subfolder: (Optional) Subfolder within the repo
- torch_dtype: (Optional) torch dtype to use for testing
Pytest mark: single_file
Use `pytest -m "not single_file"` to skip these tests
"""
pretrained_model_name_or_path = None
ckpt_path = None
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_model_config(self):
pretrained_kwargs = {}
single_file_kwargs = {}
pretrained_kwargs["device"] = torch_device
single_file_kwargs["device"] = torch_device
if hasattr(self, "subfolder") and self.subfolder:
pretrained_kwargs["subfolder"] = self.subfolder
if hasattr(self, "torch_dtype") and self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between pretrained loading and single file loading: "
f"pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_model_parameters(self):
pretrained_kwargs = {}
single_file_kwargs = {}
pretrained_kwargs["device"] = torch_device
single_file_kwargs["device"] = torch_device
if hasattr(self, "subfolder") and self.subfolder:
pretrained_kwargs["subfolder"] = self.subfolder
if hasattr(self, "torch_dtype") and self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
state_dict = model.state_dict()
state_dict_single_file = model_single_file.state_dict()
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
"Model parameters keys differ between pretrained and single file loading. "
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
)
for key in state_dict.keys():
param = state_dict[key]
param_single_file = state_dict_single_file[key]
assert param.shape == param_single_file.shape, (
f"Parameter shape mismatch for {key}: "
f"pretrained {param.shape} vs single file {param_single_file.shape}"
)
assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
f"Parameter values differ for {key}: "
f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
)
def test_single_file_loading_local_files_only(self):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
with tempfile.TemporaryDirectory() as tmpdir:
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
model_single_file = self.model_class.from_single_file(
local_ckpt_path, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with local_files_only=True"
def test_single_file_loading_with_diffusers_config(self):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
# Load with config parameter
model_single_file = self.model_class.from_single_file(
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
)
# Load pretrained for comparison
pretrained_kwargs = {}
if hasattr(self, "subfolder") and self.subfolder:
pretrained_kwargs["subfolder"] = self.subfolder
if hasattr(self, "torch_dtype") and self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
# Compare configs
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
def test_single_file_loading_with_diffusers_config_local_files_only(self):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
with tempfile.TemporaryDirectory() as tmpdir:
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir)
model_single_file = self.model_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
def test_single_file_loading_dtype(self):
for dtype in [torch.float32, torch.float16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
# Cleanup
del model_single_file
gc.collect()
backend_empty_cache(torch_device)
def test_checkpoint_variant_loading(self):
if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
return
for ckpt_path in self.alternate_ckpt_paths:
backend_empty_cache(torch_device)
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
del model
gc.collect()
backend_empty_cache(torch_device)
+224
View File
@@ -0,0 +1,224 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pytest
import torch
from diffusers.training_utils import EMAModel
from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device
@is_training
@require_torch_accelerator_with_training
class TrainingTesterMixin:
"""
Mixin class for testing training functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Expected properties to be implemented by subclasses:
- output_shape: Tuple defining the expected output shape
Pytest mark: training
Use `pytest -m "not training"` to skip these tests
"""
def test_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
def test_training_with_ema(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
ema_model = EMAModel(model.parameters())
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model.parameters())
def test_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
init_dict = self.get_init_dict()
# at init model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
# check enable works
model.enable_gradient_checkpointing()
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
# check disable works
model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
def test_gradient_checkpointing_is_applied(self, expected_set=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if expected_set is None:
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
init_dict = self.get_init_dict()
model_class_copy = copy.copy(self.model_class)
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
modules_with_gc_enabled = {}
for submodule in model.modules():
if hasattr(submodule, "gradient_checkpointing"):
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
modules_with_gc_enabled[submodule.__class__.__name__] = True
assert set(modules_with_gc_enabled.keys()) == expected_set, (
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} "
f"do not match expected set {expected_set}"
)
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if skip is None:
skip = set()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict_copy = copy.deepcopy(inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict)
if isinstance(out, dict):
out = out.sample if hasattr(out, "sample") else out.to_tuple()[0]
# run the backwards pass on the model
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
torch.manual_seed(0)
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict_copy)
if isinstance(out_2, dict):
out_2 = out_2.sample if hasattr(out_2, "sample") else out_2.to_tuple()[0]
# run the backwards pass on the model
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
assert (
loss - loss_2
).abs() < loss_tolerance, f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
if name in skip:
continue
if param.grad is None:
continue
assert torch_all_close(
param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol
), f"Gradient mismatch for {name}"
def test_mixed_precision_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
# Test with float16
if torch.device(torch_device).type != "cpu":
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
# Test with bfloat16
if torch.device(torch_device).type != "cpu":
model.zero_grad()
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
@@ -0,0 +1,316 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BitsAndBytesTesterMixin,
GGUFTesterMixin,
IPAdapterTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class FluxTransformerTesterConfig:
model_class = FluxTransformer2DModel
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
pretrained_model_kwargs = {"subfolder": "transformer"}
def get_init_dict(self):
"""Return Flux model initialization arguments."""
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"pooled_projection_dim": 32,
"axes_dims_rope": [4, 4, 8],
}
def get_dummy_inputs(self):
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
@property
def input_shape(self):
return (16, 4)
@property
def output_shape(self):
return (16, 4)
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
"""Test that deprecated 3D img_ids and txt_ids still work."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output_1 = model(**inputs_dict).to_tuple()[0]
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
inputs_dict["txt_ids"] = text_ids_3d
inputs_dict["img_ids"] = image_ids_3d
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"are not equal as them as 2d inputs"
)
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux Transformer."""
pass
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
pass
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
pass
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""
def create_ip_adapter_state_dict(self, model):
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=(
model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
),
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux Transformer."""
pass
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux Transformer."""
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height=4, width=4):
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height=4, width=4):
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
subfolder = "transformer"
pass
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
+3 -3
View File
@@ -98,9 +98,9 @@ class GGUFCudaKernelsTests(unittest.TestCase):
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)
assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)
assert torch.allclose(
output_native, output_cuda, 1e-2
), f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
@nightly
+199 -83
View File
@@ -13,7 +13,6 @@ import struct
import sys
import tempfile
import time
import unittest
import urllib.parse
from collections import UserDict
from contextlib import contextmanager
@@ -24,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tupl
import numpy as np
import PIL.Image
import PIL.ImageOps
import pytest
import requests
from numpy.linalg import norm
from packaging import version
@@ -241,7 +241,6 @@ def parse_flag_from_env(key, default=False):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -267,7 +266,7 @@ def slow(test_case):
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)
def nightly(test_case):
@@ -277,33 +276,149 @@ def nightly(test_case):
Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case)
def is_torch_compile(test_case):
"""
Decorator marking a test that runs compile tests in the diffusers CI.
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
Decorator marking a test as a torch.compile test. These tests can be filtered using:
pytest -m "not compile" to skip
pytest -m compile to run only these tests
"""
return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)
return pytest.mark.compile(test_case)
def is_single_file(test_case):
"""
Decorator marking a test as a single file loading test. These tests can be filtered using:
pytest -m "not single_file" to skip
pytest -m single_file to run only these tests
"""
return pytest.mark.single_file(test_case)
def is_lora(test_case):
"""
Decorator marking a test as a LoRA test. These tests can be filtered using:
pytest -m "not lora" to skip
pytest -m lora to run only these tests
"""
return pytest.mark.lora(test_case)
def is_ip_adapter(test_case):
"""
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
pytest -m "not ip_adapter" to skip
pytest -m ip_adapter to run only these tests
"""
return pytest.mark.ip_adapter(test_case)
def is_training(test_case):
"""
Decorator marking a test as a training test. These tests can be filtered using:
pytest -m "not training" to skip
pytest -m training to run only these tests
"""
return pytest.mark.training(test_case)
def is_attention(test_case):
"""
Decorator marking a test as an attention test. These tests can be filtered using:
pytest -m "not attention" to skip
pytest -m attention to run only these tests
"""
return pytest.mark.attention(test_case)
def is_memory(test_case):
"""
Decorator marking a test as a memory optimization test. These tests can be filtered using:
pytest -m "not memory" to skip
pytest -m memory to run only these tests
"""
return pytest.mark.memory(test_case)
def is_cpu_offload(test_case):
"""
Decorator marking a test as a CPU offload test. These tests can be filtered using:
pytest -m "not cpu_offload" to skip
pytest -m cpu_offload to run only these tests
"""
return pytest.mark.cpu_offload(test_case)
def is_group_offload(test_case):
"""
Decorator marking a test as a group offload test. These tests can be filtered using:
pytest -m "not group_offload" to skip
pytest -m group_offload to run only these tests
"""
return pytest.mark.group_offload(test_case)
def is_bitsandbytes(test_case):
"""
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
pytest -m "not bitsandbytes" to skip
pytest -m bitsandbytes to run only these tests
"""
return pytest.mark.bitsandbytes(test_case)
def is_quanto(test_case):
"""
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
pytest -m "not quanto" to skip
pytest -m quanto to run only these tests
"""
return pytest.mark.quanto(test_case)
def is_torchao(test_case):
"""
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
pytest -m "not torchao" to skip
pytest -m torchao to run only these tests
"""
return pytest.mark.torchao(test_case)
def is_gguf(test_case):
"""
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
pytest -m "not gguf" to skip
pytest -m gguf to run only these tests
"""
return pytest.mark.gguf(test_case)
def is_modelopt(test_case):
"""
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
pytest -m "not modelopt" to skip
pytest -m modelopt to run only these tests
"""
return pytest.mark.modelopt(test_case)
def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
"""
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
return pytest.mark.skipif(not is_torch_available(), reason="test requires PyTorch")(test_case)
def require_torch_2(test_case):
"""
Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
"""
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
test_case
)
return pytest.mark.skipif(
not (is_torch_available() and is_torch_version(">=", "2.0.0")), reason="test requires PyTorch 2"
)(test_case)
def require_torch_version_greater_equal(torch_version):
@@ -311,8 +426,9 @@ def require_torch_version_greater_equal(torch_version):
def decorator(test_case):
correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
return unittest.skipUnless(
correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
return pytest.mark.skipif(
not correct_torch_version,
reason=f"test requires torch with the version greater than or equal to {torch_version}",
)(test_case)
return decorator
@@ -323,8 +439,8 @@ def require_torch_version_greater(torch_version):
def decorator(test_case):
correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
return unittest.skipUnless(
correct_torch_version, f"test requires torch with the version greater than {torch_version}"
return pytest.mark.skipif(
not correct_torch_version, reason=f"test requires torch with the version greater than {torch_version}"
)(test_case)
return decorator
@@ -332,19 +448,18 @@ def require_torch_version_greater(torch_version):
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
test_case
)
return pytest.mark.skipif(torch_device != "cuda", reason="test requires PyTorch+CUDA")(test_case)
def require_torch_cuda_compatibility(expected_compute_capability):
def decorator(test_case):
if torch.cuda.is_available():
current_compute_capability = get_torch_cuda_device_capability()
return unittest.skipUnless(
float(current_compute_capability) == float(expected_compute_capability),
"Test not supported for this compute capability.",
)
return pytest.mark.skipif(
float(current_compute_capability) != float(expected_compute_capability),
reason="Test not supported for this compute capability.",
)(test_case)
return test_case
return decorator
@@ -352,9 +467,7 @@ def require_torch_cuda_compatibility(expected_compute_capability):
# These decorators are for accelerator-specific behaviours that are not GPU-specific
def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
test_case
)
return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case)
def require_torch_multi_gpu(test_case):
@@ -364,11 +477,11 @@ def require_torch_multi_gpu(test_case):
-k "multi_gpu"
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
return pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="test requires multiple GPUs")(test_case)
def require_torch_multi_accelerator(test_case):
@@ -377,27 +490,28 @@ def require_torch_multi_accelerator(test_case):
without multiple hardware accelerators.
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch
return unittest.skipUnless(
torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
return pytest.mark.skipif(
not (torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1),
reason="test requires multiple hardware accelerators",
)(test_case)
def require_torch_accelerator_with_fp16(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
test_case
)
return pytest.mark.skipif(
not _is_torch_fp16_available(torch_device), reason="test requires accelerator with fp16 support"
)(test_case)
def require_torch_accelerator_with_fp64(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP64 data type."""
return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
test_case
)
return pytest.mark.skipif(
not _is_torch_fp64_available(torch_device), reason="test requires accelerator with fp64 support"
)(test_case)
def require_big_gpu_with_torch_cuda(test_case):
@@ -406,17 +520,17 @@ def require_big_gpu_with_torch_cuda(test_case):
etc.
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch
if not torch.cuda.is_available():
return unittest.skip("test requires PyTorch CUDA")(test_case)
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
device_properties = torch.cuda.get_device_properties(0)
total_memory = device_properties.total_memory / (1024**3)
return unittest.skipUnless(
total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
return pytest.mark.skipif(
total_memory < BIG_GPU_MEMORY, reason=f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
)(test_case)
@@ -430,12 +544,12 @@ def require_big_accelerator(test_case):
test_case = pytest.mark.big_accelerator(test_case)
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch
if not (torch.cuda.is_available() or torch.xpu.is_available()):
return unittest.skip("test requires PyTorch CUDA")(test_case)
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
if torch.xpu.is_available():
device_properties = torch.xpu.get_device_properties(0)
@@ -443,30 +557,30 @@ def require_big_accelerator(test_case):
device_properties = torch.cuda.get_device_properties(0)
total_memory = device_properties.total_memory / (1024**3)
return unittest.skipUnless(
total_memory >= BIG_GPU_MEMORY,
f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
return pytest.mark.skipif(
total_memory < BIG_GPU_MEMORY,
reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
)(test_case)
def require_torch_accelerator_with_training(test_case):
"""Decorator marking a test that requires an accelerator with support for training."""
return unittest.skipUnless(
is_torch_available() and backend_supports_training(torch_device),
"test requires accelerator with training support",
return pytest.mark.skipif(
not (is_torch_available() and backend_supports_training(torch_device)),
reason="test requires accelerator with training support",
)(test_case)
def skip_mps(test_case):
"""Decorator marking a test to skip if torch_device is 'mps'"""
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
return pytest.mark.skipif(torch_device == "mps", reason="test requires non 'mps' device")(test_case)
def require_flax(test_case):
"""
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
"""
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
return pytest.mark.skipif(not is_flax_available(), reason="test requires JAX & Flax")(test_case)
def require_compel(test_case):
@@ -474,21 +588,21 @@ def require_compel(test_case):
Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
the library is not installed.
"""
return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
return pytest.mark.skipif(not is_compel_available(), reason="test requires compel")(test_case)
def require_onnxruntime(test_case):
"""
Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
"""
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
return pytest.mark.skipif(not is_onnx_available(), reason="test requires onnxruntime")(test_case)
def require_note_seq(test_case):
"""
Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
"""
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
return pytest.mark.skipif(not is_note_seq_available(), reason="test requires note_seq")(test_case)
def require_accelerator(test_case):
@@ -496,14 +610,14 @@ def require_accelerator(test_case):
Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
hardware accelerator available.
"""
return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
return pytest.mark.skipif(torch_device == "cpu", reason="test requires a hardware accelerator")(test_case)
def require_torchsde(test_case):
"""
Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
"""
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
return pytest.mark.skipif(not is_torchsde_available(), reason="test requires torchsde")(test_case)
def require_peft_backend(test_case):
@@ -511,35 +625,35 @@ def require_peft_backend(test_case):
Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
transformers.
"""
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case)
def require_timm(test_case):
"""
Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
"""
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
return pytest.mark.skipif(not is_timm_available(), reason="test requires timm")(test_case)
def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
"""
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
return pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")(test_case)
def require_quanto(test_case):
"""
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
"""
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
return pytest.mark.skipif(not is_optimum_quanto_available(), reason="test requires quanto")(test_case)
def require_accelerate(test_case):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
"""
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
def require_peft_version_greater(peft_version):
@@ -552,8 +666,8 @@ def require_peft_version_greater(peft_version):
correct_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse(peft_version)
return unittest.skipUnless(
correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
return pytest.mark.skipif(
not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}"
)(test_case)
return decorator
@@ -569,9 +683,9 @@ def require_transformers_version_greater(transformers_version):
correct_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse(transformers_version)
return unittest.skipUnless(
correct_transformers_version,
f"test requires transformers with the version greater than {transformers_version}",
return pytest.mark.skipif(
not correct_transformers_version,
reason=f"test requires transformers with the version greater than {transformers_version}",
)(test_case)
return decorator
@@ -582,8 +696,9 @@ def require_accelerate_version_greater(accelerate_version):
correct_accelerate_version = is_accelerate_available() and version.parse(
version.parse(importlib.metadata.version("accelerate")).base_version
) > version.parse(accelerate_version)
return unittest.skipUnless(
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
return pytest.mark.skipif(
not correct_accelerate_version,
reason=f"Test requires accelerate with the version greater than {accelerate_version}.",
)(test_case)
return decorator
@@ -594,8 +709,8 @@ def require_bitsandbytes_version_greater(bnb_version):
correct_bnb_version = is_bitsandbytes_available() and version.parse(
version.parse(importlib.metadata.version("bitsandbytes")).base_version
) > version.parse(bnb_version)
return unittest.skipUnless(
correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
return pytest.mark.skipif(
not correct_bnb_version, reason=f"Test requires bitsandbytes with the version greater than {bnb_version}."
)(test_case)
return decorator
@@ -606,8 +721,9 @@ def require_hf_hub_version_greater(hf_hub_version):
correct_hf_hub_version = version.parse(
version.parse(importlib.metadata.version("huggingface_hub")).base_version
) > version.parse(hf_hub_version)
return unittest.skipUnless(
correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
return pytest.mark.skipif(
not correct_hf_hub_version,
reason=f"Test requires huggingface_hub with the version greater than {hf_hub_version}.",
)(test_case)
return decorator
@@ -618,8 +734,8 @@ def require_gguf_version_greater_or_equal(gguf_version):
correct_gguf_version = is_gguf_available() and version.parse(
version.parse(importlib.metadata.version("gguf")).base_version
) >= version.parse(gguf_version)
return unittest.skipUnless(
correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
return pytest.mark.skipif(
not correct_gguf_version, reason=f"Test requires gguf with the version greater than {gguf_version}."
)(test_case)
return decorator
@@ -630,8 +746,8 @@ def require_torchao_version_greater_or_equal(torchao_version):
correct_torchao_version = is_torchao_available() and version.parse(
version.parse(importlib.metadata.version("torchao")).base_version
) >= version.parse(torchao_version)
return unittest.skipUnless(
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
return pytest.mark.skipif(
not correct_torchao_version, reason=f"Test requires torchao with version greater than {torchao_version}."
)(test_case)
return decorator
@@ -642,8 +758,8 @@ def require_kernels_version_greater_or_equal(kernels_version):
correct_kernels_version = is_kernels_available() and version.parse(
version.parse(importlib.metadata.version("kernels")).base_version
) >= version.parse(kernels_version)
return unittest.skipUnless(
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
return pytest.mark.skipif(
not correct_kernels_version, reason=f"Test requires kernels with version greater than {kernels_version}."
)(test_case)
return decorator
@@ -653,7 +769,7 @@ def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
"""
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
return pytest.mark.skipif(USE_PEFT_BACKEND, reason="test skipped in favor of PEFT backend")(test_case)
def get_python_version():
@@ -1064,8 +1180,8 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
Args:
test_case (`unittest.TestCase`):
The test that will run `target_func`.
test_case:
The test case object that will run `target_func`.
target_func (`Callable`):
The function implementing the actual testing logic.
inputs (`dict`, *optional*, defaults to `None`):
@@ -1083,7 +1199,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
input_queue = ctx.Queue(1)
output_queue = ctx.JoinableQueue(1)
# We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
# We can't send test case objects to the child, otherwise we get issues regarding pickle.
input_queue.put(inputs, timeout=timeout)
process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))