Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d23c775e86 |
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# LoopSequentialPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
|
||||
|
||||
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
|
||||
|
||||
@@ -21,6 +21,7 @@ This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBl
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
|
||||
|
||||
- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
|
||||
- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
|
||||
- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
|
||||
- `__call__` method defines the loop structure and iteration logic.
|
||||
|
||||
@@ -89,4 +90,4 @@ Add more loop blocks to run within each iteration with [`~modular_pipelines.Loop
|
||||
|
||||
```py
|
||||
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
|
||||
```
|
||||
```
|
||||
@@ -37,7 +37,17 @@ A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermedi
|
||||
]
|
||||
```
|
||||
|
||||
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
|
||||
- `intermediate_inputs` are values typically created from a previous block but it can also be directly provided if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
|
||||
|
||||
Use `InputParam` to define `intermediate_inputs`.
|
||||
|
||||
```py
|
||||
user_intermediate_inputs = [
|
||||
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
|
||||
]
|
||||
```
|
||||
|
||||
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or available as the final output from running the pipeline.
|
||||
|
||||
Use `OutputParam` to define `intermediate_outputs`.
|
||||
|
||||
@@ -55,8 +65,8 @@ The intermediate inputs and outputs share data to connect blocks. They are acces
|
||||
|
||||
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
|
||||
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
|
||||
2. Implement the computation logic on the `inputs`.
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
|
||||
2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
|
||||
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
|
||||
4. Return the components and state which becomes available to the next block.
|
||||
|
||||
@@ -66,7 +76,7 @@ def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Your computation logic here
|
||||
# block_state contains all your inputs
|
||||
# block_state contains all your inputs and intermediate_inputs
|
||||
# Access them like: block_state.image, block_state.processed_image
|
||||
|
||||
# Update the pipeline state with your updated block_states
|
||||
@@ -102,4 +112,4 @@ def __call__(self, components, state):
|
||||
unet = components.unet
|
||||
vae = components.vae
|
||||
scheduler = components.scheduler
|
||||
```
|
||||
```
|
||||
@@ -183,7 +183,7 @@ from diffusers.modular_pipelines import ComponentsManager
|
||||
components = ComponentManager()
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
|
||||
dd_pipeline.load_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.to("cuda")
|
||||
```
|
||||
|
||||
|
||||
@@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# SequentialPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
|
||||
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
|
||||
|
||||
This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
|
||||
|
||||
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
|
||||
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `intermediate_inputs`.
|
||||
|
||||
<hfoptions id="sequential">
|
||||
<hfoption id="InputBlock">
|
||||
@@ -110,4 +110,4 @@ Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by cal
|
||||
```py
|
||||
print(blocks)
|
||||
print(blocks.doc)
|
||||
```
|
||||
```
|
||||
@@ -45,7 +45,7 @@ def check_size(image, height, width):
|
||||
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
|
||||
|
||||
|
||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
|
||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
|
||||
inner_image = inner_image.convert("RGBA")
|
||||
image = image.convert("RGB")
|
||||
|
||||
|
||||
@@ -1966,21 +1966,16 @@ class MatryoshkaUNet2DConditionModel(
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -2299,10 +2294,10 @@ class MatryoshkaUNet2DConditionModel(
|
||||
|
||||
def _check_config(
|
||||
self,
|
||||
down_block_types: Tuple[str, ...],
|
||||
up_block_types: Tuple[str, ...],
|
||||
down_block_types: Tuple[str],
|
||||
up_block_types: Tuple[str],
|
||||
only_cross_attention: Union[bool, Tuple[bool]],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
block_out_channels: Tuple[int],
|
||||
layers_per_block: Union[int, Tuple[int]],
|
||||
cross_attention_dim: Union[int, Tuple[int]],
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
||||
|
||||
@@ -438,21 +438,16 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
|
||||
@@ -10,7 +10,7 @@ from accelerate import init_empty_weights
|
||||
from diffusers import (
|
||||
SanaControlNetModel,
|
||||
)
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
SanaTransformer2DModel,
|
||||
SCMScheduler,
|
||||
)
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from accelerate import init_empty_weights
|
||||
|
||||
from diffusers import AutoencoderKL, SD3Transformer2DModel
|
||||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers import (
|
||||
StableAudioPipeline,
|
||||
StableAudioProjectionModel,
|
||||
)
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -407,7 +407,6 @@ else:
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"Wan22AutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"WanModularPipeline",
|
||||
]
|
||||
@@ -1091,7 +1090,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
Wan22AutoBlocks,
|
||||
WanAutoBlocks,
|
||||
WanModularPipeline,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -88,19 +88,6 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -99,19 +99,6 @@ class AdaptiveProjectedMixGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -141,16 +141,6 @@ class AutoGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -99,16 +99,6 @@ class ClassifierFreeGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -85,16 +85,6 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -226,16 +226,6 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -166,11 +166,6 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
|
||||
|
||||
def __call__(self, data: List["BlockState"]) -> Any:
|
||||
if not all(hasattr(d, "noise_pred") for d in data):
|
||||
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
||||
@@ -239,51 +234,6 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@classmethod
|
||||
def _prepare_batch_from_block_state(
|
||||
cls,
|
||||
input_fields: Dict[str, Union[str, Tuple[str, str]]],
|
||||
data: "BlockState",
|
||||
tuple_index: int,
|
||||
identifier: str,
|
||||
) -> "BlockState":
|
||||
"""
|
||||
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
|
||||
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
||||
|
||||
Args:
|
||||
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
||||
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
|
||||
length 2 is provided, the first element must be the conditional data identifier and the second element
|
||||
must be the unconditional data identifier or None.
|
||||
data (`BlockState`):
|
||||
The input data to be prepared.
|
||||
tuple_index (`int`):
|
||||
The index to use when accessing input fields that are tuples.
|
||||
|
||||
Returns:
|
||||
`BlockState`: The prepared batch of data.
|
||||
"""
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
data_batch = {}
|
||||
for key, value in input_fields.items():
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
data_batch[key] = getattr(data, value)
|
||||
elif isinstance(value, tuple):
|
||||
data_batch[key] = getattr(data, value[tuple_index])
|
||||
else:
|
||||
# We've already checked that value is a string or a tuple of strings with length 2
|
||||
pass
|
||||
except AttributeError:
|
||||
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
|
||||
@@ -187,26 +187,6 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -183,26 +183,6 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
|
||||
@@ -172,26 +172,6 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -74,16 +74,6 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -44,16 +44,11 @@ class ContextParallelConfig:
|
||||
|
||||
Args:
|
||||
ring_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
|
||||
attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
|
||||
of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
|
||||
for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
|
||||
context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
|
||||
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
ulysses_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
|
||||
local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
|
||||
KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
|
||||
good interconnect bandwidth.
|
||||
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert output and LSE to float32 for ring attention numerical stability.
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
@@ -84,46 +79,29 @@ class ContextParallelConfig:
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
|
||||
if self.ring_degree == 1 and self.ulysses_degree == 1:
|
||||
raise ValueError(
|
||||
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
|
||||
)
|
||||
if self.ring_degree < 1 or self.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if self.ring_degree > 1 and self.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
|
||||
@property
|
||||
def mesh_shape(self) -> Tuple[int, int]:
|
||||
return (self.ring_degree, self.ulysses_degree)
|
||||
|
||||
@property
|
||||
def mesh_dim_names(self) -> Tuple[str, str]:
|
||||
"""Dimension names for the device mesh."""
|
||||
return ("ring", "ulysses")
|
||||
|
||||
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._mesh = mesh
|
||||
|
||||
if self.ulysses_degree * self.ring_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
if self._flattened_mesh is None:
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
if self._ring_mesh is None:
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
if self._ulysses_mesh is None:
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
if self._ring_local_rank is None:
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
if self._ulysses_local_rank is None:
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -141,7 +119,7 @@ class ParallelConfig:
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
|
||||
def setup(
|
||||
self,
|
||||
@@ -149,14 +127,14 @@ class ParallelConfig:
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
*,
|
||||
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._mesh = mesh
|
||||
self._cp_mesh = cp_mesh
|
||||
if self.context_parallel_config is not None:
|
||||
self.context_parallel_config.setup(rank, world_size, device, mesh)
|
||||
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:
|
||||
_backends = {}
|
||||
_constraints = {}
|
||||
_supported_arg_names = {}
|
||||
_supports_context_parallel = set()
|
||||
_supports_context_parallel = {}
|
||||
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
|
||||
_checks_enabled = DIFFUSERS_ATTN_CHECKS
|
||||
|
||||
@@ -237,9 +237,7 @@ class _AttentionBackendRegistry:
|
||||
cls._backends[backend] = func
|
||||
cls._constraints[backend] = constraints or []
|
||||
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
|
||||
if supports_context_parallel:
|
||||
cls._supports_context_parallel.add(backend.value)
|
||||
|
||||
cls._supports_context_parallel[backend] = supports_context_parallel
|
||||
return func
|
||||
|
||||
return decorator
|
||||
@@ -253,12 +251,15 @@ class _AttentionBackendRegistry:
|
||||
return list(cls._backends.keys())
|
||||
|
||||
@classmethod
|
||||
def _is_context_parallel_available(
|
||||
cls,
|
||||
backend: AttentionBackendName,
|
||||
def _is_context_parallel_enabled(
|
||||
cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
|
||||
) -> bool:
|
||||
supports_context_parallel = backend.value in cls._supports_context_parallel
|
||||
return supports_context_parallel
|
||||
supports_context_parallel = backend in cls._supports_context_parallel
|
||||
is_degree_greater_than_1 = parallel_config is not None and (
|
||||
parallel_config.context_parallel_config.ring_degree > 1
|
||||
or parallel_config.context_parallel_config.ulysses_degree > 1
|
||||
)
|
||||
return supports_context_parallel and is_degree_greater_than_1
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -305,6 +306,14 @@ def dispatch_attention_fn(
|
||||
backend_name = AttentionBackendName(backend)
|
||||
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
|
||||
|
||||
if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
|
||||
backend_name, parallel_config
|
||||
):
|
||||
raise ValueError(
|
||||
f"Backend {backend_name} either does not support context parallelism or context parallelism "
|
||||
f"was enabled with a world size of 1."
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"query": query,
|
||||
"key": key,
|
||||
|
||||
@@ -102,7 +102,7 @@ def get_block(
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: Tuple[int, ...] = (),
|
||||
qkv_mutliscales: Tuple[int] = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
@@ -206,8 +206,8 @@ class Encoder(nn.Module):
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
@@ -292,8 +292,8 @@ class Decoder(nn.Module):
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: Union[str, Tuple[str]] = "rms_norm",
|
||||
act_fn: Union[str, Tuple[str]] = "silu",
|
||||
@@ -440,8 +440,8 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
|
||||
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
|
||||
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: Tuple[str] = (
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: Tuple[str] = (
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
||||
block_out_channels: Tuple[int] = (128, 256, 256, 512),
|
||||
latent_channels: int = 16,
|
||||
layers_per_block: int = 3,
|
||||
act_fn: str = "silu",
|
||||
|
||||
@@ -653,7 +653,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: int = 32,
|
||||
|
||||
@@ -601,7 +601,7 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
|
||||
@@ -688,8 +688,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 15,
|
||||
out_channels: int = 3,
|
||||
encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
|
||||
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
|
||||
latent_channels: int = 12,
|
||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
act_fn: str = "silu",
|
||||
|
||||
@@ -679,7 +679,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
|
||||
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
latent_channels: int = 4,
|
||||
sample_size: int = 32,
|
||||
|
||||
@@ -971,7 +971,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
base_dim: int = 96,
|
||||
decoder_base_dim: Optional[int] = None,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
|
||||
@@ -293,14 +293,14 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
time_embedding_mix: float = 1.0,
|
||||
learn_time_embedding: bool = False,
|
||||
num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
|
||||
base_block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
cross_attention_dim: int = 1024,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
@@ -436,7 +436,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
time_embedding_mix: int = 1.0,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
|
||||
@@ -529,19 +529,14 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
# unet configs
|
||||
sample_size: Optional[int] = 96,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
@@ -555,10 +550,10 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
# additional controlnet configs
|
||||
time_embedding_mix: float = 1.0,
|
||||
ctrl_conditioning_channels: int = 3,
|
||||
ctrl_conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_channel_order: str = "rgb",
|
||||
ctrl_learn_time_embedding: bool = False,
|
||||
ctrl_block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
|
||||
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
ctrl_max_norm_num_groups: int = 32,
|
||||
):
|
||||
|
||||
@@ -1484,71 +1484,59 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
config: Union[ParallelConfig, ContextParallelConfig],
|
||||
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
|
||||
):
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
logger.warning(
|
||||
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
|
||||
)
|
||||
|
||||
if not torch.distributed.is_available() and not torch.distributed.is_initialized():
|
||||
raise RuntimeError(
|
||||
"torch.distributed must be available and initialized before calling `enable_parallelism`."
|
||||
)
|
||||
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
if isinstance(config, ContextParallelConfig):
|
||||
config = ParallelConfig(context_parallel_config=config)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device_type = torch._C._get_accelerator().type
|
||||
device_module = torch.get_device_module(device_type)
|
||||
device = torch.device(device_type, rank % device_module.device_count())
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
|
||||
processor = module.processor
|
||||
if processor is None or not hasattr(processor, "_attention_backend"):
|
||||
continue
|
||||
|
||||
attention_backend = processor._attention_backend
|
||||
if attention_backend is None:
|
||||
attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
else:
|
||||
attention_backend = AttentionBackendName(attention_backend)
|
||||
|
||||
if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
|
||||
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
|
||||
raise ValueError(
|
||||
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
|
||||
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
|
||||
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
|
||||
f"calling `enable_parallelism()`."
|
||||
)
|
||||
|
||||
# All modules use the same attention processor and backend. We don't need to
|
||||
# iterate over all modules after checking the first processor
|
||||
break
|
||||
|
||||
mesh = None
|
||||
cp_mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=cp_config.mesh_shape,
|
||||
mesh_dim_names=cp_config.mesh_dim_names,
|
||||
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
|
||||
mesh_dim_names=("ring", "ulysses"),
|
||||
)
|
||||
|
||||
config.setup(rank, world_size, device, mesh=mesh)
|
||||
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
|
||||
|
||||
if cp_plan is None and self._cp_plan is None:
|
||||
raise ValueError(
|
||||
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
|
||||
)
|
||||
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
apply_context_parallel(self, config.context_parallel_config, cp_plan)
|
||||
|
||||
self._parallel_config = config
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
@@ -1557,14 +1545,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
continue
|
||||
processor._parallel_config = config
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
if cp_plan is None and self._cp_plan is None:
|
||||
raise ValueError(
|
||||
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
|
||||
)
|
||||
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
|
||||
apply_context_parallel(self, config.context_parallel_config, cp_plan)
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
|
||||
@@ -914,7 +914,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
text_embed_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -139,7 +139,7 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
text_embed_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
has_image_proj: int = False,
|
||||
image_proj_dim: int = 1152,
|
||||
|
||||
@@ -689,7 +689,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
text_embed_dim: int = 3584,
|
||||
text_embed_2_dim: Optional[int] = None,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int, ...] = (64, 64),
|
||||
rope_axes_dim: Tuple[int] = (64, 64),
|
||||
use_meanflow: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -172,6 +172,7 @@ class SanaLinearAttnProcessor3_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -389,10 +389,6 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
|
||||
@@ -416,7 +412,11 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
@@ -570,7 +570,7 @@ class SkyReelsV2Transformer3DModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
|
||||
@@ -362,11 +362,6 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
@@ -392,7 +387,11 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
@@ -564,7 +563,7 @@ class WanTransformer3DModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
|
||||
@@ -182,7 +182,7 @@ class WanVACETransformer3DModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
|
||||
@@ -86,11 +86,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: str = "UNetMidBlock1D",
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int, ...] = (32, 32, 64),
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
|
||||
@@ -177,21 +177,16 @@ class UNet2DConditionModel(
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -491,10 +486,10 @@ class UNet2DConditionModel(
|
||||
|
||||
def _check_config(
|
||||
self,
|
||||
down_block_types: Tuple[str, ...],
|
||||
up_block_types: Tuple[str, ...],
|
||||
down_block_types: Tuple[str],
|
||||
up_block_types: Tuple[str],
|
||||
only_cross_attention: Union[bool, Tuple[bool]],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
block_out_channels: Tuple[int],
|
||||
layers_per_block: Union[int, Tuple[int]],
|
||||
cross_attention_dim: Union[int, Tuple[int]],
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
||||
|
||||
@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
||||
groups: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
layers_per_block: Union[int, Tuple[int]] = 3,
|
||||
block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
|
||||
block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 4096,
|
||||
encoder_hid_dim: int = 4096,
|
||||
):
|
||||
|
||||
@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 8,
|
||||
out_channels: int = 4,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: Tuple[str] = (
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
addition_time_embed_dim: int = 256,
|
||||
projection_class_embeddings_input_dim: int = 768,
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
||||
num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
|
||||
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
|
||||
num_frames: int = 25,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
timestep_ratio_embedding_dim: int = 64,
|
||||
patch_size: int = 1,
|
||||
conditioning_dim: int = 2048,
|
||||
block_out_channels: Tuple[int, ...] = (2048, 2048),
|
||||
num_attention_heads: Tuple[int, ...] = (32, 32),
|
||||
down_num_layers_per_block: Tuple[int, ...] = (8, 24),
|
||||
up_num_layers_per_block: Tuple[int, ...] = (24, 8),
|
||||
block_out_channels: Tuple[int] = (2048, 2048),
|
||||
num_attention_heads: Tuple[int] = (32, 32),
|
||||
down_num_layers_per_block: Tuple[int] = (8, 24),
|
||||
up_num_layers_per_block: Tuple[int] = (24, 8),
|
||||
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
|
||||
1,
|
||||
1,
|
||||
@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
kernel_size=3,
|
||||
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
|
||||
self_attn: Union[bool, Tuple[bool]] = True,
|
||||
timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
|
||||
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
|
||||
switch_level: Optional[Tuple[bool]] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module):
|
||||
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
|
||||
block_out_channels: Tuple[int, ...] = (64,)
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
||||
block_out_channels: Tuple[int] = (64,)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
act_fn: str = "silu"
|
||||
@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module):
|
||||
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int, ...] = (64,)
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: int = (64,)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
act_fn: str = "silu"
|
||||
@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int, ...] = (64,)
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int] = (64,)
|
||||
layers_per_block: int = 1
|
||||
act_fn: str = "silu"
|
||||
latent_channels: int = 4
|
||||
|
||||
@@ -45,7 +45,7 @@ else:
|
||||
"InsertableDict",
|
||||
]
|
||||
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxAutoBlocks",
|
||||
"FluxModularPipeline",
|
||||
@@ -90,7 +90,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
|
||||
from .wan import WanAutoBlocks, WanModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -1441,8 +1441,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
components_manager: Optional[ComponentsManager] = None,
|
||||
collection: Optional[str] = None,
|
||||
modular_config_dict: Optional[Dict[str, Any]] = None,
|
||||
config_dict: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -1494,8 +1492,23 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
|
||||
`_blocks_class_name` in the config dict
|
||||
"""
|
||||
if blocks is None:
|
||||
blocks_class_name = self.default_blocks_name
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
blocks = blocks_class()
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
|
||||
if modular_config_dict is None and config_dict is None and pretrained_model_name_or_path is not None:
|
||||
self.blocks = blocks
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
|
||||
|
||||
# update component_specs and config_specs from modular_repo
|
||||
if pretrained_model_name_or_path is not None:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -1511,59 +1524,52 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
}
|
||||
# try to load modular_model_index.json
|
||||
try:
|
||||
config_dict = self.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f"modular_model_index.json not found: {e}")
|
||||
config_dict = None
|
||||
|
||||
modular_config_dict, config_dict = self._load_pipeline_config(
|
||||
pretrained_model_name_or_path, **load_config_kwargs
|
||||
)
|
||||
# update component_specs and config_specs based on modular_model_index.json
|
||||
if config_dict is not None:
|
||||
for name, value in config_dict.items():
|
||||
# all the components in modular_model_index.json are from_pretrained components
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
|
||||
library, class_name, component_spec_dict = value
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
|
||||
if blocks is None:
|
||||
if modular_config_dict is not None:
|
||||
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
||||
elif config_dict is not None:
|
||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
|
||||
# if modular_model_index.json is not found, try to load model_index.json
|
||||
else:
|
||||
blocks_class_name = None
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
blocks = blocks_class()
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
logger.debug(" loading config from model_index.json")
|
||||
try:
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
self.blocks = blocks
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
|
||||
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
config_dict = None
|
||||
|
||||
# update component_specs and config_specs based on modular_model_index.json
|
||||
if modular_config_dict is not None:
|
||||
for name, value in modular_config_dict.items():
|
||||
# all the components in modular_model_index.json are from_pretrained components
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
|
||||
library, class_name, component_spec_dict = value
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
|
||||
# if `modular_config_dict` is None (i.e. `modular_model_index.json` is not found), update based on `config_dict` (i.e. `model_index.json`)
|
||||
elif config_dict is not None:
|
||||
for name, value in config_dict.items():
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
|
||||
library, class_name = value
|
||||
component_spec_dict = {
|
||||
"repo": pretrained_model_name_or_path,
|
||||
"subfolder": name,
|
||||
"type_hint": (library, class_name),
|
||||
}
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
# update component_specs and config_specs based on model_index.json
|
||||
if config_dict is not None:
|
||||
for name, value in config_dict.items():
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
|
||||
library, class_name = value
|
||||
component_spec_dict = {
|
||||
"repo": pretrained_model_name_or_path,
|
||||
"subfolder": name,
|
||||
"type_hint": (library, class_name),
|
||||
}
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.")
|
||||
@@ -1595,35 +1601,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
return self.default_blocks_name
|
||||
|
||||
@classmethod
|
||||
def _load_pipeline_config(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
**load_config_kwargs,
|
||||
):
|
||||
try:
|
||||
# try to load modular_model_index.json
|
||||
modular_config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
return modular_config_dict, None
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" modular_model_index.json not found in the repo: {e}")
|
||||
|
||||
try:
|
||||
logger.debug(" try to load model_index.json")
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
return None, config_dict
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
|
||||
return None, None
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
@@ -1678,33 +1655,42 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"revision": revision,
|
||||
}
|
||||
|
||||
modular_config_dict, config_dict = cls._load_pipeline_config(
|
||||
pretrained_model_name_or_path, **load_config_kwargs
|
||||
)
|
||||
try:
|
||||
# try to load modular_model_index.json
|
||||
config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" modular_model_index.json not found in the repo: {e}")
|
||||
config_dict = None
|
||||
|
||||
if modular_config_dict is not None:
|
||||
pipeline_class = _get_pipeline_class(cls, config=modular_config_dict)
|
||||
elif config_dict is not None:
|
||||
from diffusers.pipelines.auto_pipeline import _get_model
|
||||
|
||||
logger.debug(" try to determine the modular pipeline class from model_index.json")
|
||||
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
model_name = _get_model(standard_pipeline_class.__name__)
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
if config_dict is not None:
|
||||
pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
else:
|
||||
# there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
|
||||
pipeline_class = cls
|
||||
pretrained_model_name_or_path = None
|
||||
try:
|
||||
logger.debug(" try to load model_index.json")
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.pipelines.auto_pipeline import _get_model
|
||||
|
||||
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
|
||||
if config_dict is not None:
|
||||
logger.debug(" try to determine the modular pipeline class from model_index.json")
|
||||
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
model_name = _get_model(standard_pipeline_class.__name__)
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
else:
|
||||
# there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
|
||||
pipeline_class = cls
|
||||
pretrained_model_name_or_path = None
|
||||
|
||||
pipeline = pipeline_class(
|
||||
blocks=blocks,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
components_manager=components_manager,
|
||||
collection=collection,
|
||||
modular_config_dict=modular_config_dict,
|
||||
config_dict=config_dict,
|
||||
**kwargs,
|
||||
)
|
||||
return pipeline
|
||||
@@ -2148,9 +2134,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
logger.warning(
|
||||
f"\nFailed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n"
|
||||
"If this component is not required for your workflow you can safely ignore this message.\n\n"
|
||||
"Traceback:\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n\n"
|
||||
f"{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
|
||||
@@ -21,14 +21,16 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
|
||||
_import_structure["encoders"] = ["WanTextEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"Wan22AutoBlocks",
|
||||
"AUTO_BLOCKS",
|
||||
"TEXT2VIDEO_BLOCKS",
|
||||
"WanAutoBeforeDenoiseStep",
|
||||
"WanAutoBlocks",
|
||||
"WanAutoImageEncoderStep",
|
||||
"WanAutoVaeImageEncoderStep",
|
||||
"WanAutoBlocks",
|
||||
"WanAutoDecodeStep",
|
||||
"WanAutoDenoiseStep",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
|
||||
|
||||
@@ -39,14 +41,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .decoders import WanImageVaeDecoderStep
|
||||
from .encoders import WanTextEncoderStep
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
Wan22AutoBlocks,
|
||||
AUTO_BLOCKS,
|
||||
TEXT2VIDEO_BLOCKS,
|
||||
WanAutoBeforeDenoiseStep,
|
||||
WanAutoBlocks,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
WanAutoDecodeStep,
|
||||
WanAutoDenoiseStep,
|
||||
)
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
else:
|
||||
|
||||
@@ -13,11 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import WanTransformer3DModel
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
@@ -35,97 +34,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
# configuration of guider is.
|
||||
|
||||
|
||||
def repeat_tensor_to_batch_size(
|
||||
input_name: str,
|
||||
input_tensor: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_videos_per_prompt: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Repeat tensor elements to match the final batch size.
|
||||
|
||||
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt)
|
||||
by repeating each element along dimension 0.
|
||||
|
||||
The input tensor must have batch size 1 or batch_size. The function will:
|
||||
- If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times
|
||||
- If batch size equals batch_size: repeat each element num_videos_per_prompt times
|
||||
|
||||
Args:
|
||||
input_name (str): Name of the input tensor (used for error messages)
|
||||
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
|
||||
batch_size (int): The base batch size (number of prompts)
|
||||
num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt)
|
||||
|
||||
Raises:
|
||||
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
|
||||
|
||||
Examples:
|
||||
tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor,
|
||||
batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape:
|
||||
[4, 3]
|
||||
|
||||
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image",
|
||||
tensor, batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])
|
||||
- shape: [4, 3]
|
||||
"""
|
||||
# make sure input is a tensor
|
||||
if not isinstance(input_tensor, torch.Tensor):
|
||||
raise ValueError(f"`{input_name}` must be a tensor")
|
||||
|
||||
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
|
||||
if input_tensor.shape[0] == 1:
|
||||
repeat_by = batch_size * num_videos_per_prompt
|
||||
elif input_tensor.shape[0] == batch_size:
|
||||
repeat_by = num_videos_per_prompt
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
|
||||
)
|
||||
|
||||
# expand the tensor to match the batch_size * num_videos_per_prompt
|
||||
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
||||
def calculate_dimension_from_latents(
|
||||
latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int
|
||||
) -> Tuple[int, int]:
|
||||
"""Calculate image dimensions from latent tensor dimensions.
|
||||
|
||||
This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by
|
||||
multiplying the latent num_frames/height/width by the VAE scale factor.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
|
||||
Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
|
||||
vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension.
|
||||
Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension)
|
||||
vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension.
|
||||
Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The calculated image dimensions as (height, width)
|
||||
|
||||
Raises:
|
||||
ValueError: If latents tensor doesn't have 4 or 5 dimensions
|
||||
|
||||
"""
|
||||
if latents.ndim != 5:
|
||||
raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}")
|
||||
|
||||
_, _, num_latent_frames, latent_height, latent_width = latents.shape
|
||||
|
||||
num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1
|
||||
height = latent_height * vae_scale_factor_spatial
|
||||
width = latent_width * vae_scale_factor_spatial
|
||||
|
||||
return num_frames, height, width
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
@@ -186,7 +94,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class WanTextInputStep(ModularPipelineBlocks):
|
||||
class WanInputStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -200,16 +108,15 @@ class WanTextInputStep(ModularPipelineBlocks):
|
||||
"have a final batch_size of batch_size * num_videos_per_prompt."
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", WanTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_videos_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -234,7 +141,19 @@ class WanTextInputStep(ModularPipelineBlocks):
|
||||
OutputParam(
|
||||
"dtype",
|
||||
type_hint=torch.dtype,
|
||||
description="Data type of model tensor inputs (determined by `transformer.dtype`)",
|
||||
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
|
||||
),
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
|
||||
description="negative text embeddings used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -275,140 +194,6 @@ class WanTextInputStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanAdditionalInputsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["first_frame_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
1. For encoded image latents, use it update height/width if None, and expands batch size
|
||||
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
|
||||
|
||||
This is a dynamic block that allows you to configure which inputs to process.
|
||||
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
|
||||
a single string or list of strings. Defaults to ["first_frame_latents"].
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to [].
|
||||
|
||||
Examples:
|
||||
# Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs WanAdditionalInputsStep(
|
||||
image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
|
||||
)
|
||||
"""
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
# Functionality section
|
||||
summary_section = (
|
||||
"Input processing step that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
# Inputs info
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_videos_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
InputParam(name="num_frames"),
|
||||
]
|
||||
|
||||
# Add image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
# Add additional batch inputs
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate num_frames, height/width from latents
|
||||
num_frames, height, width = calculate_dimension_from_latents(
|
||||
image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial
|
||||
)
|
||||
block_state.num_frames = block_state.num_frames or num_frames
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
# 3. Expand batch size
|
||||
image_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_tensor=image_latent_tensor,
|
||||
num_videos_per_prompt=block_state.num_videos_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, image_latent_input_name, image_latent_tensor)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
num_videos_per_prompt=block_state.num_videos_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@@ -430,15 +215,26 @@ class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
InputParam("sigmas"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||
OutputParam(
|
||||
"num_inference_steps",
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
device = components._execution_device
|
||||
block_state.device = components._execution_device
|
||||
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
components.scheduler,
|
||||
block_state.num_inference_steps,
|
||||
device,
|
||||
block_state.device,
|
||||
block_state.timesteps,
|
||||
block_state.sigmas,
|
||||
)
|
||||
@@ -450,6 +246,10 @@ class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare latents step that prepares the latents for the text-to-video generation process"
|
||||
@@ -462,6 +262,11 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
InputParam("num_frames", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_videos_per_prompt", type_hint=int, default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -532,106 +337,29 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
device = components._execution_device
|
||||
dtype = torch.float32 # Wan latents should be torch.float32 for best quality
|
||||
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
block_state.num_frames = block_state.num_frames or components.default_num_frames
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
block_state.latents = self.prepare_latents(
|
||||
components,
|
||||
batch_size=block_state.batch_size * block_state.num_videos_per_prompt,
|
||||
num_channels_latents=components.num_channels_latents,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
num_frames=block_state.num_frames,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=block_state.generator,
|
||||
latents=block_state.latents,
|
||||
block_state.batch_size * block_state.num_videos_per_prompt,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.num_frames,
|
||||
block_state.dtype,
|
||||
block_state.device,
|
||||
block_state.generator,
|
||||
block_state.latents,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked first frame latents and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
|
||||
block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
|
||||
block_state.first_last_frame_latents = torch.concat(
|
||||
[mask_lat_size, block_state.first_last_frame_latents], dim=1
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanImageVaeDecoderStep(ModularPipelineBlocks):
|
||||
class WanDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -50,6 +50,12 @@ class WanImageVaeDecoderStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -74,20 +80,25 @@ class WanImageVaeDecoderStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
vae_dtype = components.vae.dtype
|
||||
|
||||
latents = block_state.latents
|
||||
latents_mean = (
|
||||
torch.tensor(components.vae.config.latents_mean)
|
||||
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||
1, components.vae.config.z_dim, 1, 1, 1
|
||||
).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
latents = latents.to(vae_dtype)
|
||||
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
|
||||
if not block_state.output_type == "latent":
|
||||
latents = block_state.latents
|
||||
latents_mean = (
|
||||
torch.tensor(components.vae.config.latents_mean)
|
||||
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||
1, components.vae.config.z_dim, 1, 1, 1
|
||||
).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
latents = latents.to(vae_dtype)
|
||||
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
|
||||
else:
|
||||
block_state.videos = block_state.latents
|
||||
|
||||
block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np")
|
||||
block_state.videos = components.video_processor.postprocess_video(
|
||||
block_state.videos, output_type=block_state.output_type
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -27,156 +27,16 @@ from ..modular_pipeline import (
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = block_state.latents.to(block_state.dtype)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"first_frame_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
|
||||
block_state.dtype
|
||||
)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"first_last_frame_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = torch.cat(
|
||||
[block_state.latents, block_state.first_last_frame_latents], dim=1
|
||||
).to(block_state.dtype)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")},
|
||||
):
|
||||
"""Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.1.
|
||||
|
||||
Args:
|
||||
guider_input_fields: A dictionary that maps each argument expected by the denoiser model
|
||||
(for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either:
|
||||
|
||||
- A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds",
|
||||
"negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and
|
||||
`block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
|
||||
'encoder_hidden_states'.
|
||||
- A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward
|
||||
`block_state.image_embeds` for both conditional and unconditional batches.
|
||||
"""
|
||||
if not isinstance(guider_input_fields, dict):
|
||||
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
|
||||
self._guider_input_fields = guider_input_fields
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
@@ -199,30 +59,49 @@ class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
inputs = [
|
||||
return [
|
||||
InputParam("attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description=(
|
||||
"All conditional model inputs that need to be prepared with guider. "
|
||||
"It should contain prompt_embeds/negative_prompt_embeds. "
|
||||
"Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||
),
|
||||
),
|
||||
]
|
||||
guider_input_names = []
|
||||
for value in self._guider_input_fields.values():
|
||||
if isinstance(value, tuple):
|
||||
guider_input_names.extend(value)
|
||||
else:
|
||||
guider_input_names.append(value)
|
||||
|
||||
for name in guider_input_names:
|
||||
inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||
guider_inputs = {
|
||||
"prompt_embeds": (
|
||||
getattr(block_state, "prompt_embeds", None),
|
||||
getattr(block_state, "negative_prompt_embeds", None),
|
||||
),
|
||||
}
|
||||
transformer_dtype = components.transformer.dtype
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
@@ -233,26 +112,22 @@ class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {
|
||||
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in cond_kwargs.items()
|
||||
if k in self._guider_input_fields.keys()
|
||||
}
|
||||
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
|
||||
prompt_embeds = cond_kwargs.pop("prompt_embeds")
|
||||
|
||||
# Predict the noise residual
|
||||
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latent_model_input.to(block_state.dtype),
|
||||
timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
|
||||
hidden_states=block_state.latents.to(transformer_dtype),
|
||||
timestep=t.flatten(),
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
@@ -262,141 +137,6 @@ class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class Wan22LoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")},
|
||||
):
|
||||
"""Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.2.
|
||||
|
||||
Args:
|
||||
guider_input_fields: A dictionary that maps each argument expected by the denoiser model
|
||||
(for example, "encoder_hidden_states") to data stored on `block_state`. The value can be either:
|
||||
|
||||
- A tuple of strings. For instance, `{"encoder_hidden_states": ("prompt_embeds",
|
||||
"negative_prompt_embeds")}` tells the guider to read `block_state.prompt_embeds` and
|
||||
`block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
|
||||
`encoder_hidden_states`.
|
||||
- A string. For example, `{"encoder_hidden_image": "image_embeds"}` makes the guider forward
|
||||
`block_state.image_embeds` for both conditional and unconditional batches.
|
||||
"""
|
||||
if not isinstance(guider_input_fields, dict):
|
||||
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
|
||||
self._guider_input_fields = guider_input_fields
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec(
|
||||
"guider_2",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 3.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("transformer", WanTransformer3DModel),
|
||||
ComponentSpec("transformer_2", WanTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that denoise the latents with guidance. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(
|
||||
name="boundary_ratio",
|
||||
default=0.875,
|
||||
description="The boundary ratio to divide the denoising loop into high noise and low noise stages.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
inputs = [
|
||||
InputParam("attention_kwargs"),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
guider_input_names = []
|
||||
for value in self._guider_input_fields.values():
|
||||
if isinstance(value, tuple):
|
||||
guider_input_names.extend(value)
|
||||
else:
|
||||
guider_input_names.append(value)
|
||||
|
||||
for name in guider_input_names:
|
||||
inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps
|
||||
if t >= boundary_timestep:
|
||||
block_state.current_model = components.transformer
|
||||
block_state.guider = components.guider
|
||||
else:
|
||||
block_state.current_model = components.transformer_2
|
||||
block_state.guider = components.guider_2
|
||||
|
||||
block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
block_state.guider.prepare_models(block_state.current_model)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {
|
||||
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in cond_kwargs.items()
|
||||
if k in self._guider_input_fields.keys()
|
||||
}
|
||||
|
||||
# Predict the noise residual
|
||||
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
|
||||
guider_state_batch.noise_pred = block_state.current_model(
|
||||
hidden_states=block_state.latent_model_input.to(block_state.dtype),
|
||||
timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
block_state.guider.cleanup_models(block_state.current_model)
|
||||
|
||||
# Perform guidance
|
||||
block_state.noise_pred = block_state.guider(guider_state)[0]
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@@ -414,6 +154,20 @@ class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
# Perform scheduler step using the predicted output
|
||||
@@ -444,11 +198,18 @@ class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
@property
|
||||
def loop_expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 5.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("scheduler", UniPCMultistepScheduler),
|
||||
ComponentSpec("transformer", WanTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -487,12 +248,7 @@ class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
|
||||
class WanDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanLoopBeforeDenoiser,
|
||||
WanLoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
}
|
||||
),
|
||||
WanLoopDenoiser,
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
@@ -503,110 +259,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports text-to-video tasks for wan2.1."
|
||||
)
|
||||
|
||||
|
||||
class Wan22DenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanLoopBeforeDenoiser,
|
||||
Wan22LoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanLoopBeforeDenoiser`\n"
|
||||
" - `Wan22LoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports text-to-video tasks for Wan2.2."
|
||||
)
|
||||
|
||||
|
||||
class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanImage2VideoLoopBeforeDenoiser,
|
||||
WanLoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_image": "image_embeds",
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanImage2VideoLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports image-to-video tasks for wan2.1."
|
||||
)
|
||||
|
||||
|
||||
class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanImage2VideoLoopBeforeDenoiser,
|
||||
Wan22LoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanImage2VideoLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports image-to-video tasks for Wan2.2."
|
||||
)
|
||||
|
||||
|
||||
class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanFLF2VLoopBeforeDenoiser,
|
||||
WanLoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_image": "image_embeds",
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanFLF2VLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports FLF2V tasks for wan2.1."
|
||||
"This block supports both text2vid tasks."
|
||||
)
|
||||
|
||||
@@ -15,29 +15,21 @@
|
||||
import html
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLWan
|
||||
from ...utils import is_ftfy_available, is_torchvision_available, logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ...utils import is_ftfy_available, logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -59,103 +51,6 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
def get_t5_prompt_embeds(
|
||||
text_encoder: UMT5EncoderModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
prompt: Union[str, List[str]],
|
||||
max_sequence_length: int,
|
||||
device: torch.device,
|
||||
):
|
||||
dtype = text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def encode_image(
|
||||
image: PipelineImageInput,
|
||||
image_processor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModel,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
image = image_processor(images=image, return_tensors="pt").to(device)
|
||||
image_embeds = image_encoder(**image, output_hidden_states=True)
|
||||
return image_embeds.hidden_states[-2]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def encode_vae_image(
|
||||
video_tensor: torch.Tensor,
|
||||
vae: AutoencoderKLWan,
|
||||
generator: torch.Generator,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
latent_channels: int = 16,
|
||||
):
|
||||
if not isinstance(video_tensor, torch.Tensor):
|
||||
raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.")
|
||||
|
||||
if isinstance(generator, list) and len(generator) != video_tensor.shape[0]:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}."
|
||||
)
|
||||
|
||||
video_tensor = video_tensor.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list):
|
||||
video_latents = [
|
||||
retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
||||
for i in range(video_tensor.shape[0])
|
||||
]
|
||||
video_latents = torch.cat(video_latents, dim=0)
|
||||
else:
|
||||
video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax")
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(vae.config.latents_mean)
|
||||
.view(1, latent_channels, 1, 1, 1)
|
||||
.to(video_latents.device, video_latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, latent_channels, 1, 1, 1).to(
|
||||
video_latents.device, video_latents.dtype
|
||||
)
|
||||
video_latents = (video_latents - latents_mean) * latents_std
|
||||
|
||||
return video_latents
|
||||
|
||||
|
||||
class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@@ -176,12 +71,16 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("max_sequence_length", default=512),
|
||||
InputParam("attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -208,13 +107,47 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||
|
||||
@staticmethod
|
||||
def _get_t5_prompt_embeds(
|
||||
components,
|
||||
prompt: Union[str, List[str]],
|
||||
max_sequence_length: int,
|
||||
device: torch.device,
|
||||
):
|
||||
dtype = components.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
|
||||
text_inputs = components.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
def encode_prompt(
|
||||
components,
|
||||
prompt: str,
|
||||
device: Optional[torch.device] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prepare_unconditional_embeds: bool = True,
|
||||
negative_prompt: Optional[str] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
@@ -225,29 +158,32 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_videos_per_prompt (`int`):
|
||||
number of videos that should be generated per prompt
|
||||
prepare_unconditional_embeds (`bool`):
|
||||
whether to use prepare unconditional embeddings or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of text tokens to be used for the generation process.
|
||||
"""
|
||||
device = device or components._execution_device
|
||||
if not isinstance(prompt, list):
|
||||
prompt = [prompt]
|
||||
batch_size = len(prompt)
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
|
||||
|
||||
prompt_embeds = get_t5_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)
|
||||
|
||||
if prepare_unconditional_embeds:
|
||||
if prepare_unconditional_embeds and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
@@ -263,14 +199,18 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = get_t5_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=negative_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
|
||||
components, negative_prompt, max_sequence_length, device
|
||||
)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if prepare_unconditional_embeds:
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -279,6 +219,7 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(block_state)
|
||||
|
||||
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
|
||||
block_state.device = components._execution_device
|
||||
|
||||
# Encode input prompt
|
||||
@@ -286,382 +227,16 @@ class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
block_state.prompt_embeds,
|
||||
block_state.negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
components=components,
|
||||
prompt=block_state.prompt,
|
||||
device=block_state.device,
|
||||
prepare_unconditional_embeds=components.requires_unconditional_embeds,
|
||||
negative_prompt=block_state.negative_prompt,
|
||||
max_sequence_length=block_state.max_sequence_length,
|
||||
components,
|
||||
block_state.prompt,
|
||||
block_state.device,
|
||||
1,
|
||||
block_state.prepare_unconditional_embeds,
|
||||
block_state.negative_prompt,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
)
|
||||
|
||||
# Add outputs
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanImageResizeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height", type_hint=int, default=480),
|
||||
InputParam("width", type_hint=int, default=832),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("resized_image", type_hint=PIL.Image.Image),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
max_area = block_state.height * block_state.width
|
||||
|
||||
image = block_state.image
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial
|
||||
block_state.height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
block_state.width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
block_state.resized_image = image.resize((block_state.width, block_state.height))
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanImageCropResizeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Resize step that resize the last_image to the same size of first frame image with center crop."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image"
|
||||
),
|
||||
InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frameimage"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("resized_last_image", type_hint=PIL.Image.Image),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
height = block_state.resized_image.height
|
||||
width = block_state.resized_image.width
|
||||
image = block_state.last_image
|
||||
|
||||
# Calculate resize ratio to match first frame dimensions
|
||||
resize_ratio = max(width / image.width, height / image.height)
|
||||
|
||||
# Resize the image
|
||||
width = round(image.width * resize_ratio)
|
||||
height = round(image.height * resize_ratio)
|
||||
size = [width, height]
|
||||
resized_image = transforms.functional.center_crop(image, size)
|
||||
block_state.resized_last_image = resized_image
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Encoder step that generate image_embeds based on first frame image to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("image_processor", CLIPImageProcessor),
|
||||
ComponentSpec("image_encoder", CLIPVisionModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
image = block_state.resized_image
|
||||
|
||||
image_embeds = encode_image(
|
||||
image_processor=components.image_processor,
|
||||
image_encoder=components.image_encoder,
|
||||
image=image,
|
||||
device=device,
|
||||
)
|
||||
block_state.image_embeds = image_embeds
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Encoder step that generate image_embeds based on first and last frame images to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("image_processor", CLIPImageProcessor),
|
||||
ComponentSpec("image_encoder", CLIPVisionModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
first_frame_image = block_state.resized_image
|
||||
last_frame_image = block_state.resized_last_image
|
||||
|
||||
image_embeds = encode_image(
|
||||
image_processor=components.image_processor,
|
||||
image_encoder=components.image_encoder,
|
||||
image=[first_frame_image, last_frame_image],
|
||||
device=device,
|
||||
)
|
||||
block_state.image_embeds = image_embeds
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Image Encoder step that generate condition_latents based on first frame image to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLWan),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 8}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("num_frames"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"first_frame_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="video latent representation with the first frame image condition",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(components, block_state):
|
||||
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
|
||||
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
|
||||
)
|
||||
if block_state.num_frames is not None and (
|
||||
block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
|
||||
)
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
image = block_state.resized_image
|
||||
|
||||
device = components._execution_device
|
||||
dtype = torch.float32
|
||||
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
num_frames = block_state.num_frames or components.default_num_frames
|
||||
|
||||
image_tensor = components.video_processor.preprocess(image, height=height, width=width).to(
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
|
||||
if image_tensor.dim() == 4:
|
||||
image_tensor = image_tensor.unsqueeze(2)
|
||||
|
||||
video_tensor = torch.cat(
|
||||
[
|
||||
image_tensor,
|
||||
image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width),
|
||||
],
|
||||
dim=2,
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
block_state.first_frame_latents = encode_vae_image(
|
||||
video_tensor=video_tensor,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Image Encoder step that generate condition_latents based on first and last frame images to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLWan),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 8}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("num_frames"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"first_last_frame_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="video latent representation with the first and last frame images condition",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(components, block_state):
|
||||
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
|
||||
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
|
||||
)
|
||||
if block_state.num_frames is not None and (
|
||||
block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
|
||||
)
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
first_frame_image = block_state.resized_image
|
||||
last_frame_image = block_state.resized_last_image
|
||||
|
||||
device = components._execution_device
|
||||
dtype = torch.float32
|
||||
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
num_frames = block_state.num_frames or components.default_num_frames
|
||||
|
||||
first_image_tensor = components.video_processor.preprocess(first_frame_image, height=height, width=width).to(
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
first_image_tensor = first_image_tensor.unsqueeze(2)
|
||||
|
||||
last_image_tensor = components.video_processor.preprocess(last_frame_image, height=height, width=width).to(
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
|
||||
last_image_tensor = last_image_tensor.unsqueeze(2)
|
||||
|
||||
video_tensor = torch.cat(
|
||||
[
|
||||
first_image_tensor,
|
||||
first_image_tensor.new_zeros(
|
||||
first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width
|
||||
),
|
||||
last_image_tensor,
|
||||
],
|
||||
dim=2,
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
block_state.first_last_frame_latents = encode_vae_image(
|
||||
video_tensor=video_tensor,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -16,244 +16,96 @@ from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanInputStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanImageVaeDecoderStep
|
||||
from .denoise import (
|
||||
Wan22DenoiseStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
WanDenoiseStep,
|
||||
WanFLF2VDenoiseStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanFirstLastFrameImageEncoderStep,
|
||||
WanFirstLastFrameVaeImageEncoderStep,
|
||||
WanImageCropResizeStep,
|
||||
WanImageEncoderStep,
|
||||
WanImageResizeStep,
|
||||
WanTextEncoderStep,
|
||||
WanVaeImageEncoderStep,
|
||||
)
|
||||
from .decoders import WanDecodeStep
|
||||
from .denoise import WanDenoiseStep
|
||||
from .encoders import WanTextEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# wan2.1
|
||||
# wan2.1: text2vid
|
||||
class WanCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
# before_denoise: text2vid
|
||||
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: image2video
|
||||
## image encoder
|
||||
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageEncoderStep]
|
||||
block_names = ["image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "vae_image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
|
||||
|
||||
|
||||
## denoise
|
||||
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
# before_denoise: all task (text2vid,)
|
||||
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_frame_latents",
|
||||
"denoise",
|
||||
WanBeforeDenoiseStep,
|
||||
]
|
||||
block_names = ["text2vid"]
|
||||
block_trigger_inputs = [None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2vid.\n"
|
||||
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: FLF2v
|
||||
|
||||
|
||||
## image encoder
|
||||
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
|
||||
|
||||
|
||||
## denoise
|
||||
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanFLF2VDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_last_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: auto blocks
|
||||
## image encoder
|
||||
class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
|
||||
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Image Encoder step that encode the image to generate the image embeddings"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
|
||||
block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae Image Encoder step that encode the image to generate the image latents"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## denoise
|
||||
# denoise: text2vid
|
||||
class WanAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
WanFLF2VCoreDenoiseStep,
|
||||
WanImage2VideoCoreDenoiseStep,
|
||||
WanCoreDenoiseStep,
|
||||
WanDenoiseStep,
|
||||
]
|
||||
block_names = ["flf2v", "image2video", "text2video"]
|
||||
block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]
|
||||
block_names = ["denoise"]
|
||||
block_trigger_inputs = [None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2video and image2video tasks."
|
||||
" - `WanCoreDenoiseStep` (text2video) for text2vid tasks."
|
||||
" - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks."
|
||||
+ " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n"
|
||||
+ " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n"
|
||||
"This is a auto pipeline block that works for text2vid tasks.."
|
||||
" - `WanDenoiseStep` (denoise) for text2vid tasks."
|
||||
)
|
||||
|
||||
|
||||
# auto pipeline blocks
|
||||
# decode: all task (text2img, img2img, inpainting)
|
||||
class WanAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [WanDecodeStep]
|
||||
block_names = ["non-inpaint"]
|
||||
block_trigger_inputs = [None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`"
|
||||
|
||||
|
||||
# text2vid
|
||||
class WanAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
WanAutoBeforeDenoiseStep,
|
||||
WanAutoDenoiseStep,
|
||||
WanImageVaeDecoderStep,
|
||||
WanAutoDecodeStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"image_encoder",
|
||||
"vae_image_encoder",
|
||||
"before_denoise",
|
||||
"denoise",
|
||||
"decode",
|
||||
"decoder",
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -264,211 +116,29 @@ class WanAutoBlocks(SequentialPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# wan22
|
||||
# wan2.2: text2vid
|
||||
|
||||
|
||||
## denoise
|
||||
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
Wan22DenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.2: image2video
|
||||
## denoise
|
||||
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
|
||||
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
class Wan22AutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
Wan22Image2VideoCoreDenoiseStep,
|
||||
Wan22CoreDenoiseStep,
|
||||
]
|
||||
block_names = ["image2video", "text2video"]
|
||||
block_trigger_inputs = ["first_frame_latents", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2video and image2video tasks."
|
||||
" - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks."
|
||||
" - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks."
|
||||
+ " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n"
|
||||
+ " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n"
|
||||
)
|
||||
|
||||
|
||||
class Wan22AutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
Wan22AutoDenoiseStep,
|
||||
WanImageVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"vae_image_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-video using Wan2.2.\n"
|
||||
+ "- for text-to-video generation, all you need to provide is `prompt`"
|
||||
)
|
||||
|
||||
|
||||
# presets for wan2.1 and wan2.2
|
||||
# YiYi Notes: should we move these to doc?
|
||||
# wan2.1
|
||||
TEXT2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("input", WanInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", WanDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
("decode", WanDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("image_encoder", WanImage2VideoImageEncoderStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
|
||||
("denoise", WanImage2VideoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
FLF2V_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("last_image_resize", WanImageCropResizeStep),
|
||||
("image_encoder", WanFLF2VImageEncoderStep),
|
||||
("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
|
||||
("denoise", WanFLF2VDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("image_encoder", WanAutoImageEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("before_denoise", WanAutoBeforeDenoiseStep),
|
||||
("denoise", WanAutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
("decode", WanAutoDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
# wan2.2 presets
|
||||
|
||||
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", Wan22DenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", Wan22DenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", Wan22AutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
# presets all blocks (wan and wan22)
|
||||
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"wan2.1": {
|
||||
"text2video": TEXT2VIDEO_BLOCKS,
|
||||
"image2video": IMAGE2VIDEO_BLOCKS,
|
||||
"flf2v": FLF2V_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
},
|
||||
"wan2.2": {
|
||||
"text2video": TEXT2VIDEO_BLOCKS_WAN22,
|
||||
"image2video": IMAGE2VIDEO_BLOCKS_WAN22,
|
||||
"auto": AUTO_BLOCKS_WAN22,
|
||||
},
|
||||
"text2video": TEXT2VIDEO_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
}
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...pipelines.pipeline_utils import StableDiffusionMixin
|
||||
from ...utils import logging
|
||||
@@ -37,13 +35,6 @@ class WanModularPipeline(
|
||||
|
||||
default_blocks_name = "WanAutoBlocks"
|
||||
|
||||
# override the default_blocks_name in base class, which is just return self.default_blocks_name
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22AutoBlocks"
|
||||
else:
|
||||
return "WanAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_height * self.vae_scale_factor_spatial
|
||||
@@ -68,13 +59,6 @@ class WanModularPipeline(
|
||||
def default_sample_num_frames(self):
|
||||
return 21
|
||||
|
||||
@property
|
||||
def patch_size_spatial(self):
|
||||
patch_size_spatial = 2
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
patch_size_spatial = self.transformer.config.patch_size[1]
|
||||
return patch_size_spatial
|
||||
|
||||
@property
|
||||
def vae_scale_factor_spatial(self):
|
||||
vae_scale_factor = 8
|
||||
@@ -102,19 +86,3 @@ class WanModularPipeline(
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
num_channels_latents = self.vae.config.z_dim
|
||||
return num_channels_latents
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
requires_unconditional_embeds = False
|
||||
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
|
||||
|
||||
return requires_unconditional_embeds
|
||||
|
||||
@property
|
||||
def num_train_timesteps(self):
|
||||
num_train_timesteps = 1000
|
||||
if hasattr(self, "scheduler") and self.scheduler is not None:
|
||||
num_train_timesteps = self.scheduler.config.num_train_timesteps
|
||||
return num_train_timesteps
|
||||
|
||||
@@ -245,21 +245,16 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
out_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
|
||||
@@ -117,7 +117,6 @@ from .stable_diffusion_xl import (
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
|
||||
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
|
||||
|
||||
|
||||
@@ -215,24 +214,6 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("wan", WanPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("wan", WanImageToVideoPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("wan", WanVideoToVideoPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("kandinsky", KandinskyPipeline),
|
||||
@@ -266,9 +247,6 @@ SUPPORTED_TASKS_MAPPINGS = [
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
|
||||
AUTO_INPAINT_PIPELINES_MAPPING,
|
||||
AUTO_TEXT2VIDEO_PIPELINES_MAPPING,
|
||||
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING,
|
||||
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING,
|
||||
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING,
|
||||
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING,
|
||||
_AUTO_INPAINT_DECODER_PIPELINES_MAPPING,
|
||||
|
||||
@@ -374,21 +374,21 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlockFlat",
|
||||
"CrossAttnDownBlockFlat",
|
||||
"CrossAttnDownBlockFlat",
|
||||
"DownBlockFlat",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: Tuple[str] = (
|
||||
"UpBlockFlat",
|
||||
"CrossAttnUpBlockFlat",
|
||||
"CrossAttnUpBlockFlat",
|
||||
"CrossAttnUpBlockFlat",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
|
||||
@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
param_names: Tuple[str, ...] = (
|
||||
param_names: Tuple[str] = (
|
||||
"nerstf.mlp.0.weight",
|
||||
"nerstf.mlp.1.weight",
|
||||
"nerstf.mlp.2.weight",
|
||||
@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
param_names: Tuple[str, ...] = (
|
||||
param_names: Tuple[str] = (
|
||||
"nerstf.mlp.0.weight",
|
||||
"nerstf.mlp.1.weight",
|
||||
"nerstf.mlp.2.weight",
|
||||
"nerstf.mlp.3.weight",
|
||||
),
|
||||
param_shapes: Tuple[Tuple[int, int], ...] = (
|
||||
param_shapes: Tuple[Tuple[int]] = (
|
||||
(256, 93),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
@@ -804,7 +804,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
n_hidden_layers: int = 6,
|
||||
act_fn: str = "swish",
|
||||
insert_direction_at: int = 4,
|
||||
background: Tuple[float, ...] = (
|
||||
background: Tuple[float] = (
|
||||
255.0,
|
||||
255.0,
|
||||
255.0,
|
||||
|
||||
@@ -182,21 +182,6 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22AutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -32,20 +32,6 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
|
||||
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
|
||||
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
|
||||
config.addinivalue_line("markers", "training: marks tests for training functionality")
|
||||
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
|
||||
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
|
||||
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
|
||||
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
|
||||
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
|
||||
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
|
||||
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
|
||||
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
|
||||
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
|
||||
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
|
||||
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
||||
@@ -317,9 +317,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
|
||||
)
|
||||
|
||||
assert all(
|
||||
torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())
|
||||
), "Model parameters don't match!"
|
||||
assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
|
||||
"Model parameters don't match!"
|
||||
)
|
||||
|
||||
# Remove a shard file
|
||||
cached_shard_file = try_to_load_from_cache(
|
||||
@@ -335,9 +335,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
|
||||
# Verify error mentions the missing shard
|
||||
error_msg = str(context.exception)
|
||||
assert (
|
||||
cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg
|
||||
), f"Expected error about missing shard, got: {error_msg}"
|
||||
assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
|
||||
f"Expected error about missing shard, got: {error_msg}"
|
||||
)
|
||||
|
||||
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
|
||||
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
|
||||
@@ -354,9 +354,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
download_requests = [r.method for r in m.request_history]
|
||||
assert (
|
||||
download_requests.count("HEAD") == 3
|
||||
), "3 HEAD requests one for config, one for model, and one for shard index file."
|
||||
assert download_requests.count("HEAD") == 3, (
|
||||
"3 HEAD requests one for config, one for model, and one for shard index file."
|
||||
)
|
||||
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
|
||||
|
||||
with requests_mock.mock(real_http=True) as m:
|
||||
@@ -368,9 +368,9 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
cache_requests = [r.method for r in m.request_history]
|
||||
assert (
|
||||
"HEAD" == cache_requests[0] and len(cache_requests) == 2
|
||||
), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
|
||||
assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
|
||||
"We should call only `model_info` to check for commit hash and knowing if shard index is present."
|
||||
)
|
||||
|
||||
def test_weight_overwrite(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from .attention import AttentionTesterMixin
|
||||
from .common import ModelTesterMixin
|
||||
from .compile import TorchCompileTesterMixin
|
||||
from .ip_adapter import IPAdapterTesterMixin
|
||||
from .lora import LoraTesterMixin
|
||||
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
|
||||
from .quantization import (
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
QuantizationTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
)
|
||||
from .single_file import SingleFileTesterMixin
|
||||
from .training import TrainingTesterMixin
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionTesterMixin",
|
||||
"BitsAndBytesTesterMixin",
|
||||
"CPUOffloadTesterMixin",
|
||||
"GGUFTesterMixin",
|
||||
"GroupOffloadTesterMixin",
|
||||
"IPAdapterTesterMixin",
|
||||
"LayerwiseCastingTesterMixin",
|
||||
"LoraTesterMixin",
|
||||
"MemoryTesterMixin",
|
||||
"ModelOptTesterMixin",
|
||||
"ModelTesterMixin",
|
||||
"QuantizationTesterMixin",
|
||||
"QuantoTesterMixin",
|
||||
"SingleFileTesterMixin",
|
||||
"TorchAoTesterMixin",
|
||||
"TorchCompileTesterMixin",
|
||||
"TrainingTesterMixin",
|
||||
]
|
||||
@@ -1,180 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention import AttentionModuleMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
|
||||
from ...testing_utils import is_attention, require_accelerator, torch_device
|
||||
|
||||
|
||||
@is_attention
|
||||
@require_accelerator
|
||||
class AttentionTesterMixin:
|
||||
"""
|
||||
Mixin class for testing attention processor and module functionality on models.
|
||||
|
||||
Tests functionality from AttentionModuleMixin including:
|
||||
- Attention processor management (set/get)
|
||||
- QKV projection fusion/unfusion
|
||||
- Attention backends (XFormers, NPU, etc.)
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- base_precision: Tolerance for floating point comparisons (default: 1e-3)
|
||||
- uses_custom_attn_processor: Whether model uses custom attention processors (default: False)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: attention
|
||||
Use `pytest -m "not attention"` to skip these tests
|
||||
"""
|
||||
|
||||
base_precision = 1e-3
|
||||
|
||||
def test_fuse_unfuse_qkv_projections(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if not hasattr(model, "fuse_qkv_projections"):
|
||||
pytest.skip("Model does not support QKV projection fusion.")
|
||||
|
||||
# Get output before fusion
|
||||
with torch.no_grad():
|
||||
output_before_fusion = model(**inputs_dict)
|
||||
if isinstance(output_before_fusion, dict):
|
||||
output_before_fusion = output_before_fusion.to_tuple()[0]
|
||||
|
||||
# Fuse projections
|
||||
model.fuse_qkv_projections()
|
||||
|
||||
# Verify fusion occurred by checking for fused attributes
|
||||
has_fused_projections = False
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
|
||||
has_fused_projections = True
|
||||
assert module.fused_projections, "fused_projections flag should be True"
|
||||
break
|
||||
|
||||
if has_fused_projections:
|
||||
# Get output after fusion
|
||||
with torch.no_grad():
|
||||
output_after_fusion = model(**inputs_dict)
|
||||
if isinstance(output_after_fusion, dict):
|
||||
output_after_fusion = output_after_fusion.to_tuple()[0]
|
||||
|
||||
# Verify outputs match
|
||||
assert torch.allclose(
|
||||
output_before_fusion, output_after_fusion, atol=self.base_precision
|
||||
), "Output should not change after fusing projections"
|
||||
|
||||
# Unfuse projections
|
||||
model.unfuse_qkv_projections()
|
||||
|
||||
# Verify unfusion occurred
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
|
||||
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
|
||||
assert not module.fused_projections, "fused_projections flag should be False"
|
||||
|
||||
# Get output after unfusion
|
||||
with torch.no_grad():
|
||||
output_after_unfusion = model(**inputs_dict)
|
||||
if isinstance(output_after_unfusion, dict):
|
||||
output_after_unfusion = output_after_unfusion.to_tuple()[0]
|
||||
|
||||
# Verify outputs still match
|
||||
assert torch.allclose(
|
||||
output_before_fusion, output_after_unfusion, atol=self.base_precision
|
||||
), "Output should match original after unfusing projections"
|
||||
|
||||
def test_get_set_processor(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# Check if model has attention processors
|
||||
if not hasattr(model, "attn_processors"):
|
||||
pytest.skip("Model does not have attention processors.")
|
||||
|
||||
# Test getting processors
|
||||
processors = model.attn_processors
|
||||
assert isinstance(processors, dict), "attn_processors should return a dict"
|
||||
assert len(processors) > 0, "Model should have at least one attention processor"
|
||||
|
||||
# Test that all processors can be retrieved via get_processor
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
processor = module.get_processor()
|
||||
assert processor is not None, "get_processor should return a processor"
|
||||
|
||||
# Test setting a new processor
|
||||
new_processor = AttnProcessor()
|
||||
module.set_processor(new_processor)
|
||||
retrieved_processor = module.get_processor()
|
||||
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
|
||||
|
||||
def test_attention_processor_dict(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict of new processors
|
||||
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
|
||||
|
||||
# Set processors using dict
|
||||
model.set_attn_processor(new_processors)
|
||||
|
||||
# Verify all processors were set
|
||||
updated_processors = model.attn_processors
|
||||
for key in current_processors.keys():
|
||||
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
|
||||
|
||||
def test_attention_processor_count_mismatch_raises_error(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict with wrong number of processors
|
||||
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
|
||||
|
||||
# Verify error is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.set_attn_processor(wrong_processors)
|
||||
|
||||
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
|
||||
@@ -1,514 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
||||
|
||||
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant
|
||||
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
|
||||
|
||||
def compute_module_persistent_sizes(
|
||||
model: nn.Module,
|
||||
dtype: Optional[Union[str, torch.device]] = None,
|
||||
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the size of each submodule of a given model (parameters + persistent buffers).
|
||||
"""
|
||||
if dtype is not None:
|
||||
dtype = _get_proper_dtype(dtype)
|
||||
dtype_size = dtype_byte_size(dtype)
|
||||
if special_dtypes is not None:
|
||||
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
module_sizes = defaultdict(int)
|
||||
|
||||
module_list = []
|
||||
|
||||
module_list = named_persistent_module_tensors(model, recurse=True)
|
||||
|
||||
for name, tensor in module_list:
|
||||
if special_dtypes is not None and name in special_dtypes:
|
||||
size = tensor.numel() * special_dtypes_size[name]
|
||||
elif dtype is None:
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
# According to the code in set_module_tensor_to_device, these types won't be converted
|
||||
# so use their original size here
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
else:
|
||||
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||
name_parts = name.split(".")
|
||||
for idx in range(len(name_parts) + 1):
|
||||
module_sizes[".".join(name_parts[:idx])] += size
|
||||
|
||||
return module_sizes
|
||||
|
||||
|
||||
def calculate_expected_num_shards(index_map_path):
|
||||
"""
|
||||
Calculate expected number of shards from index file.
|
||||
|
||||
Args:
|
||||
index_map_path: Path to the sharded checkpoint index file
|
||||
|
||||
Returns:
|
||||
int: Expected number of shards
|
||||
"""
|
||||
with open(index_map_path) as f:
|
||||
weight_map_dict = json.load(f)["weight_map"]
|
||||
first_key = list(weight_map_dict.keys())[0]
|
||||
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||
return expected_num_shards
|
||||
|
||||
|
||||
def check_device_map_is_respected(model, device_map):
|
||||
for param_name, param in model.named_parameters():
|
||||
# Find device in device_map
|
||||
while len(param_name) > 0 and param_name not in device_map:
|
||||
param_name = ".".join(param_name.split(".")[:-1])
|
||||
if param_name not in device_map:
|
||||
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
|
||||
|
||||
param_device = device_map[param_name]
|
||||
if param_device in ["cpu", "disk"]:
|
||||
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
|
||||
else:
|
||||
assert param.device == torch.device(
|
||||
param_device
|
||||
), f"Expected device {param_device} for {param_name}, got {param.device}"
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
"""
|
||||
Base mixin class for model testing with common test methods.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
|
||||
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
model_class = None
|
||||
base_precision = 1e-3
|
||||
model_split_percents = [0.5, 0.7]
|
||||
|
||||
def get_init_dict(self):
|
||||
raise NotImplementedError("get_init_dict must be implemented by subclasses. ")
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
raise NotImplementedError(
|
||||
"get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict."
|
||||
)
|
||||
|
||||
def test_from_save_pretrained(self, expected_max_diff=5e-5):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname)
|
||||
new_model.to(torch_device)
|
||||
|
||||
# check if all parameters shape are the same
|
||||
for param_name in model.state_dict().keys():
|
||||
param_1 = model.state_dict()[param_name]
|
||||
param_2 = new_model.state_dict()[param_name]
|
||||
assert (
|
||||
param_1.shape == param_2.shape
|
||||
), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**self.get_dummy_inputs())
|
||||
|
||||
if isinstance(image, dict):
|
||||
image = image.to_tuple()[0]
|
||||
|
||||
new_image = new_model(**self.get_dummy_inputs())
|
||||
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
|
||||
max_diff = (image - new_image).abs().max().item()
|
||||
assert (
|
||||
max_diff <= expected_max_diff
|
||||
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
|
||||
|
||||
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, variant="fp16")
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
|
||||
|
||||
# non-variant cannot be loaded
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
self.model_class.from_pretrained(tmpdirname)
|
||||
|
||||
# make sure that error message states what keys are missing
|
||||
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**self.get_dummy_inputs())
|
||||
if isinstance(image, dict):
|
||||
image = image.to_tuple()[0]
|
||||
|
||||
new_image = new_model(**self.get_dummy_inputs())
|
||||
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
|
||||
max_diff = (image - new_image).abs().max().item()
|
||||
assert (
|
||||
max_diff <= expected_max_diff
|
||||
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
|
||||
|
||||
def test_from_save_pretrained_dtype(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
continue
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.to(dtype)
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
if (
|
||||
hasattr(self.model_class, "_keep_in_fp32_modules")
|
||||
and self.model_class._keep_in_fp32_modules is None
|
||||
):
|
||||
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
|
||||
)
|
||||
assert new_model.dtype == dtype
|
||||
|
||||
def test_determinism(self, expected_max_diff=1e-5):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
first = model(**self.get_dummy_inputs())
|
||||
if isinstance(first, dict):
|
||||
first = first.to_tuple()[0]
|
||||
|
||||
second = model(**self.get_dummy_inputs())
|
||||
if isinstance(second, dict):
|
||||
second = second.to_tuple()[0]
|
||||
|
||||
# Remove NaN values and compute max difference
|
||||
first_flat = first.flatten()
|
||||
second_flat = second.flatten()
|
||||
|
||||
# Filter out NaN values
|
||||
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
|
||||
first_filtered = first_flat[mask]
|
||||
second_filtered = second_flat[mask]
|
||||
|
||||
max_diff = torch.abs(first_filtered - second_filtered).max().item()
|
||||
assert (
|
||||
max_diff <= expected_max_diff
|
||||
), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"
|
||||
|
||||
def test_output(self, expected_output_shape=None):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None, "Model output is None"
|
||||
assert (
|
||||
output.shape == expected_output_shape
|
||||
), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
|
||||
|
||||
def test_outputs_equivalence(self):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
|
||||
# Track progress in https://github.com/pytorch/pytorch/issues/77764
|
||||
device = t.device
|
||||
if device.type == "mps":
|
||||
t = t.to("cpu")
|
||||
t[t != t] = 0
|
||||
return t.to(device)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
assert torch.allclose(
|
||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
||||
), (
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
||||
)
|
||||
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs_dict = model(**self.get_dummy_inputs())
|
||||
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
def test_model_config_to_json_string(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
|
||||
json_string = model.config.to_json_string()
|
||||
assert isinstance(json_string, str), "Config to_json_string should return a string"
|
||||
assert len(json_string) > 0, "JSON string should not be empty"
|
||||
|
||||
@require_accelerator
|
||||
@pytest.mark.skipif(torch_device not in ["cuda", "xpu"])
|
||||
def test_from_save_pretrained_float16_bfloat16(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
fp32_modules = model._keep_in_fp32_modules
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
for torch_dtype in [torch.bfloat16, torch.float16]:
|
||||
model.to(torch_dtype).save_pretrained(tmp_dir)
|
||||
model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device)
|
||||
|
||||
for name, param in model_loaded.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||
assert param.data.dtype == torch.float32
|
||||
else:
|
||||
assert param.data.dtype == torch_dtype
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**get_dummy_inputs())
|
||||
output_loaded = model_loaded(**get_dummy_inputs())
|
||||
|
||||
assert torch.allclose(
|
||||
output, output_loaded, atol=1e-4
|
||||
), f"Loaded model output differs for {torch_dtype}"
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints(self):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
assert (
|
||||
actual_num_shards == expected_num_shards
|
||||
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match after sharded save/load"
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints_with_variant(self):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
variant = "fp16"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
|
||||
|
||||
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
assert os.path.exists(
|
||||
os.path.join(tmp_dir, index_filename)
|
||||
), f"Variant index file {index_filename} should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
assert (
|
||||
actual_num_shards == expected_num_shards
|
||||
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match after variant sharded save/load"
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints_with_parallel_loading(self):
|
||||
import time
|
||||
|
||||
from diffusers.utils import constants
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
|
||||
# Save original values to restore after test
|
||||
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
|
||||
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
assert (
|
||||
actual_num_shards == expected_num_shards
|
||||
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
|
||||
# Load without parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = False
|
||||
start_time = time.time()
|
||||
model_sequential = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
sequential_load_time = time.time() - start_time
|
||||
model_sequential = model_sequential.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Load with parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = True
|
||||
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
|
||||
|
||||
start_time = time.time()
|
||||
model_parallel = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
parallel_load_time = time.time() - start_time
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_parallel = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict_parallel)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], output_parallel[0], atol=1e-5
|
||||
), "Output should match with parallel loading"
|
||||
|
||||
# Verify parallel loading is faster or at least not significantly slower
|
||||
# For small test models, the difference might be negligible or even slightly slower due to overhead
|
||||
# so we just check that parallel loading completed successfully and outputs match
|
||||
assert (
|
||||
parallel_load_time < sequential_load_time
|
||||
), f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
|
||||
finally:
|
||||
# Restore original values
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
|
||||
if original_parallel_workers is not None:
|
||||
constants.HF_PARALLEL_WORKERS = original_parallel_workers
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_model_parallelism(self):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will be on GPU 0 and GPU 1
|
||||
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match with model parallelism"
|
||||
@@ -1,162 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_torch_compile,
|
||||
require_accelerator,
|
||||
require_torch_version_greater,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_torch_compile
|
||||
@require_accelerator
|
||||
@require_torch_version_greater("2.7.1")
|
||||
class TorchCompileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing torch.compile functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: compile
|
||||
Use `pytest -m "not compile"` to skip these tests
|
||||
"""
|
||||
|
||||
different_shapes_for_compilation = None
|
||||
|
||||
def setup_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
torch.no_grad(),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
if self.model_class._repeated_blocks is None:
|
||||
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
recompile_limit = 1
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
torch.no_grad(),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_with_group_offloading(self):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
|
||||
torch._dynamo.config.cache_size_limit = 10000
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
group_offload_kwargs = {
|
||||
"onload_device": torch_device,
|
||||
"offload_device": "cpu",
|
||||
"offload_type": "block_level",
|
||||
"num_blocks_per_group": 1,
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
}
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
model.compile()
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_on_different_shapes(self):
|
||||
if self.different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True, dynamic=True)
|
||||
|
||||
for height, width in self.different_shapes_for_compilation:
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_works_with_aot(self):
|
||||
from torch._inductor.package import load_package
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
|
||||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||
assert os.path.exists(package_path), f"Package file not created at {package_path}"
|
||||
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||
|
||||
model.forward = loaded_binary
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
@@ -1,109 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub.utils import is_jinja_available
|
||||
|
||||
from ...others.test_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ModelPushToHubTesterMixin:
|
||||
"""
|
||||
Mixin class for testing push_to_hub functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
"""
|
||||
|
||||
identifier = uuid.uuid4()
|
||||
repo_id = f"test-model-{identifier}"
|
||||
org_repo_id = f"valid_org/{repo_id}-org"
|
||||
|
||||
def test_push_to_hub(self):
|
||||
"""Test pushing model to hub and loading it back."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.push_to_hub(self.repo_id, token=TOKEN)
|
||||
|
||||
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=TOKEN, repo_id=self.repo_id)
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
|
||||
|
||||
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(
|
||||
p1, p2
|
||||
), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.repo_id, token=TOKEN)
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
"""Test pushing model to hub in organization namespace."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.push_to_hub(self.org_repo_id, token=TOKEN)
|
||||
|
||||
new_model = self.model_class.from_pretrained(self.org_repo_id)
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
|
||||
|
||||
new_model = self.model_class.from_pretrained(self.org_repo_id)
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(
|
||||
p1, p2
|
||||
), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.org_repo_id, token=TOKEN)
|
||||
|
||||
def test_push_to_hub_library_name(self):
|
||||
"""Test that library_name in model card is set to 'diffusers'."""
|
||||
if not is_jinja_available():
|
||||
pytest.skip("Model card tests cannot be performed without Jinja installed.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.push_to_hub(self.repo_id, token=TOKEN)
|
||||
|
||||
model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
|
||||
assert (
|
||||
model_card.library_name == "diffusers"
|
||||
), f"Expected library_name 'diffusers', got {model_card.library_name}"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.repo_id, token=TOKEN)
|
||||
@@ -1,205 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention_processor import IPAdapterAttnProcessor
|
||||
|
||||
from ...testing_utils import is_ip_adapter, torch_device
|
||||
|
||||
|
||||
def create_ip_adapter_state_dict(model):
|
||||
"""
|
||||
Create a dummy IP Adapter state dict for testing.
|
||||
|
||||
Args:
|
||||
model: The model to create IP adapter weights for
|
||||
|
||||
Returns:
|
||||
dict: IP adapter state dict with to_k_ip and to_v_ip weights
|
||||
"""
|
||||
ip_state_dict = {}
|
||||
key_id = 1
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
# Skip self-attention processors
|
||||
cross_attention_dim = getattr(model.config, "cross_attention_dim", None)
|
||||
if cross_attention_dim is None:
|
||||
continue
|
||||
|
||||
# Get hidden size based on model architecture
|
||||
hidden_size = getattr(model.config, "hidden_size", cross_attention_dim)
|
||||
|
||||
# Create IP adapter processor to get state dict structure
|
||||
sd = IPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
|
||||
ip_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
}
|
||||
)
|
||||
key_id += 2
|
||||
|
||||
return {"ip_adapter": ip_state_dict}
|
||||
|
||||
|
||||
def check_if_ip_adapter_correctly_set(model) -> bool:
|
||||
"""
|
||||
Check if IP Adapter processors are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if IP Adapter is correctly set, False otherwise
|
||||
"""
|
||||
for module in model.attn_processors.values():
|
||||
if isinstance(module, IPAdapterAttnProcessor):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_ip_adapter
|
||||
class IPAdapterTesterMixin:
|
||||
"""
|
||||
Mixin class for testing IP Adapter functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: ip_adapter
|
||||
Use `pytest -m "not ip_adapter"` to skip these tests
|
||||
"""
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
|
||||
|
||||
def test_load_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create dummy IP adapter state dict
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
|
||||
# Load IP adapter
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model), "IP Adapter processors not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
# Create dummy image embeds for IP adapter
|
||||
cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
|
||||
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||
inputs_dict_with_adapter = inputs_dict.copy()
|
||||
inputs_dict_with_adapter["image_embeds"] = image_embeds
|
||||
|
||||
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(
|
||||
output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with IP Adapter enabled"
|
||||
|
||||
def test_ip_adapter_scale(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create and load dummy IP adapter state dict
|
||||
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
# Test scale = 0.0 (no effect)
|
||||
model.set_ip_adapter_scale(0.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Test scale = 1.0 (full effect)
|
||||
model.set_ip_adapter_scale(1.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Outputs should differ with different scales
|
||||
assert not torch.allclose(
|
||||
output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with different IP Adapter scales"
|
||||
|
||||
def test_unload_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Save original processors
|
||||
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
|
||||
# Create and load IP adapter
|
||||
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set"
|
||||
|
||||
# Unload IP adapter
|
||||
model.unload_ip_adapter()
|
||||
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
|
||||
|
||||
# Verify processors are restored
|
||||
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
assert original_processors == current_processors, "Processors should be restored after unload"
|
||||
|
||||
def test_ip_adapter_save_load(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create and load IP adapter
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict()
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_before_save = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save the IP adapter weights
|
||||
save_path = os.path.join(tmpdir, "ip_adapter.safetensors")
|
||||
import safetensors.torch
|
||||
|
||||
safetensors.torch.save_file(ip_adapter_state_dict["ip_adapter"], save_path)
|
||||
|
||||
# Unload and reload
|
||||
model.unload_ip_adapter()
|
||||
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
|
||||
|
||||
# Reload from saved file
|
||||
loaded_state_dict = {"ip_adapter": safetensors.torch.load_file(save_path)}
|
||||
model._load_ip_adapter_weights([loaded_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be loaded"
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_after_load = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Outputs should match before and after save/load
|
||||
assert torch.allclose(
|
||||
output_before_save, output_after_load, atol=1e-4, rtol=1e-4
|
||||
), "Output should match before and after save/load"
|
||||
@@ -1,220 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from diffusers.utils.testing_utils import check_if_dicts_are_equal
|
||||
|
||||
from ...testing_utils import is_lora, require_peft_backend, torch_device
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Check if LoRA layers are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if LoRA is correctly set, False otherwise
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_lora
|
||||
@require_peft_backend
|
||||
class LoraTesterMixin:
|
||||
"""
|
||||
Mixin class for testing LoRA/PEFT functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: lora
|
||||
Use `pytest -m "not lora"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||
|
||||
def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(
|
||||
output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with LoRA enabled"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
assert os.path.isfile(
|
||||
os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
), "LoRA weights file not created"
|
||||
|
||||
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
|
||||
|
||||
for k in state_dict_loaded:
|
||||
loaded_v = state_dict_loaded[k]
|
||||
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
|
||||
assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}"
|
||||
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(
|
||||
output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with LoRA enabled"
|
||||
assert torch.allclose(
|
||||
outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
|
||||
), "Outputs should match before and after save/load"
|
||||
|
||||
def test_lora_wrong_adapter_name_raises_error(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
wrong_name = "foo"
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
|
||||
|
||||
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
|
||||
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
metadata = model.peft_config["default"].to_dict()
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
parsed_metadata = model.peft_config["default_0"].to_dict()
|
||||
check_if_dicts_are_equal(metadata, parsed_metadata)
|
||||
|
||||
def test_lora_adapter_wrong_metadata_raises_error(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
# Perturb the metadata in the state dict
|
||||
loaded_state_dict = safetensors.torch.load_file(model_file)
|
||||
metadata = {"format": "pt"}
|
||||
lora_adapter_metadata = denoiser_lora_config.to_dict()
|
||||
lora_adapter_metadata.update({"foo": 1, "bar": 2})
|
||||
for key, value in lora_adapter_metadata.items():
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
|
||||
@@ -1,443 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import glob
|
||||
import inspect
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.utils.modeling import compute_module_sizes
|
||||
|
||||
from diffusers.utils.testing_utils import _check_safetensors_serialization
|
||||
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
backend_synchronize,
|
||||
is_cpu_offload,
|
||||
is_group_offload,
|
||||
is_memory,
|
||||
require_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .common import check_device_map_is_respected
|
||||
|
||||
|
||||
def cast_maybe_tensor_dtype(inputs_dict, from_dtype, to_dtype):
|
||||
"""Helper to cast tensor inputs from one dtype to another."""
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor) and value.dtype == from_dtype:
|
||||
inputs_dict[key] = value.to(to_dtype)
|
||||
return inputs_dict
|
||||
|
||||
|
||||
def require_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_group_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support group offloading.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@is_cpu_offload
|
||||
class CPUOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing CPU offloading functionality.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- model_split_percents: List of percentages for splitting model across devices
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cpu_offload
|
||||
Use `pytest -m "not cpu_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
model_split_percents = [0.5, 0.7]
|
||||
|
||||
@require_offload_support
|
||||
def test_cpu_offload(self):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will actually end up offloaded
|
||||
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match with CPU offloading"
|
||||
|
||||
@require_offload_support
|
||||
def test_disk_offload_without_safetensors(self):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
# Force disk offload by setting very small CPU memory
|
||||
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
|
||||
# This errors out because it's missing an offload folder
|
||||
with pytest.raises(ValueError):
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading"
|
||||
|
||||
@require_offload_support
|
||||
def test_disk_offload_with_safetensors(self):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match with disk offloading (safetensors)"
|
||||
|
||||
|
||||
@is_group_offload
|
||||
class GroupOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing group offloading functionality.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: group_offload
|
||||
Use `pytest -m "not group_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
@require_group_offload_support
|
||||
def test_group_offloading(self, record_stream=False):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_forward(model):
|
||||
assert all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in model.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
), "Group offloading hook should be set"
|
||||
model.eval()
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
|
||||
output_with_group_offloading1 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
|
||||
output_with_group_offloading2 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="leaf_level")
|
||||
output_with_group_offloading3 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
|
||||
)
|
||||
output_with_group_offloading4 = run_forward(model)
|
||||
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading1, atol=1e-5
|
||||
), "Output should match with block-level offloading"
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading2, atol=1e-5
|
||||
), "Output should match with non-blocking block-level offloading"
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading3, atol=1e-5
|
||||
), "Output should match with leaf-level offloading"
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading4, atol=1e-5
|
||||
), "Output should match with leaf-level offloading with stream"
|
||||
|
||||
@require_group_offload_support
|
||||
@torch.no_grad()
|
||||
def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"):
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
storage_dtype, compute_dtype = torch.float16, torch.float32
|
||||
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
|
||||
)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
@require_group_offload_support
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def test_group_offloading_with_disk(self, offload_type="block_level", record_stream=False, atol=1e-5):
|
||||
def _has_generator_arg(model):
|
||||
sig = inspect.signature(model.forward)
|
||||
params = sig.parameters
|
||||
return "generator" in params
|
||||
|
||||
def _run_forward(model, inputs_dict):
|
||||
accepts_generator = _has_generator_arg(model)
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
torch.manual_seed(0)
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.eval()
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = _run_forward(model, inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
num_blocks_per_group = None if offload_type == "leaf_level" else 1
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.enable_group_offload(
|
||||
torch_device,
|
||||
offload_type=offload_type,
|
||||
offload_to_disk_path=tmpdir,
|
||||
use_stream=True,
|
||||
record_stream=record_stream,
|
||||
**additional_kwargs,
|
||||
)
|
||||
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
|
||||
assert has_safetensors, "No safetensors found in the directory."
|
||||
|
||||
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
|
||||
# in nature. So, skip it.
|
||||
if offload_type != "leaf_level":
|
||||
is_correct, extra_files, missing_files = _check_safetensors_serialization(
|
||||
module=model,
|
||||
offload_to_disk_path=tmpdir,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
)
|
||||
if not is_correct:
|
||||
if extra_files:
|
||||
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
|
||||
elif missing_files:
|
||||
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
|
||||
|
||||
output_with_group_offloading = _run_forward(model, inputs_dict)
|
||||
assert torch.allclose(
|
||||
output_without_group_offloading, output_with_group_offloading, atol=atol
|
||||
), "Output should match with disk-based group offloading"
|
||||
|
||||
|
||||
class LayerwiseCastingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing layerwise dtype casting for memory optimization.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def test_layerwise_casting_memory(self):
|
||||
MB_TOLERANCE = 0.2
|
||||
LEAST_COMPUTE_CAPABILITY = 8.0
|
||||
|
||||
def reset_memory_stats():
|
||||
gc.collect()
|
||||
backend_synchronize(torch_device)
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
def get_memory_usage(storage_dtype, compute_dtype):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device, dtype=compute_dtype)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
|
||||
reset_memory_stats()
|
||||
model(**inputs_dict)
|
||||
model_memory_footprint = model.get_memory_footprint()
|
||||
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
|
||||
|
||||
return model_memory_footprint, peak_inference_memory_allocated_mb
|
||||
|
||||
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
|
||||
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
|
||||
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
|
||||
torch.float8_e4m3fn, torch.bfloat16
|
||||
)
|
||||
|
||||
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
|
||||
assert (
|
||||
fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint
|
||||
), "Memory footprint should decrease with lower precision storage"
|
||||
|
||||
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
|
||||
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
|
||||
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
|
||||
assert (
|
||||
fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory
|
||||
), "Peak memory should be lower with bf16 compute on newer GPUs"
|
||||
|
||||
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
|
||||
# bytes. This only happens for some models, so we allow a small tolerance.
|
||||
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
|
||||
assert (
|
||||
fp8_e4m3_fp32_max_memory < fp32_max_memory
|
||||
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
|
||||
), "Peak memory should be lower or within tolerance with fp8 storage"
|
||||
|
||||
|
||||
@is_memory
|
||||
@require_accelerator
|
||||
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
|
||||
"""
|
||||
Combined mixin class for all memory optimization tests including CPU/disk offloading,
|
||||
group offloading, and layerwise dtype casting.
|
||||
|
||||
This mixin inherits from:
|
||||
- CPUOffloadTesterMixin: CPU and disk offloading tests
|
||||
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
|
||||
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: memory
|
||||
Use `pytest -m "not memory"` to skip these tests
|
||||
"""
|
||||
|
||||
pass
|
||||
@@ -1,833 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
|
||||
from diffusers.utils.import_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_gguf_available,
|
||||
is_nvidia_modelopt_available,
|
||||
is_optimum_quanto_available,
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_bitsandbytes,
|
||||
is_gguf,
|
||||
is_modelopt,
|
||||
is_quanto,
|
||||
is_torchao,
|
||||
nightly,
|
||||
require_accelerate,
|
||||
require_accelerator,
|
||||
require_bitsandbytes_version_greater,
|
||||
require_gguf_version_greater_or_equal,
|
||||
require_quanto,
|
||||
require_torchao_version_greater_or_equal,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_nvidia_modelopt_available():
|
||||
import modelopt.torch.quantization as mtq
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
import bitsandbytes as bnb
|
||||
|
||||
if is_optimum_quanto_available():
|
||||
from optimum.quanto import QLinear
|
||||
|
||||
if is_gguf_available():
|
||||
pass
|
||||
|
||||
if is_torchao_available():
|
||||
|
||||
if is_torchao_version(">=", "0.9.0"):
|
||||
pass
|
||||
|
||||
|
||||
@require_accelerator
|
||||
class QuantizationTesterMixin:
|
||||
"""
|
||||
Base mixin class providing common test implementations for quantization testing.
|
||||
|
||||
Backend-specific mixins should:
|
||||
1. Implement _create_quantized_model(config_kwargs)
|
||||
2. Implement _verify_if_layer_quantized(name, module, config_kwargs)
|
||||
3. Define their config dict (e.g., BNB_CONFIGS, QUANTO_WEIGHT_TYPES, etc.)
|
||||
4. Use @pytest.mark.parametrize to create tests that call the common test methods below
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods in test classes:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
"""
|
||||
Create a quantized model with the given config kwargs.
|
||||
|
||||
Args:
|
||||
config_kwargs: Quantization config parameters
|
||||
**extra_kwargs: Additional kwargs to pass to from_pretrained (e.g., device_map, offload_folder)
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _create_quantized_model")
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
|
||||
|
||||
def _is_module_quantized(self, module):
|
||||
"""
|
||||
Check if a module is quantized. Returns True if quantized, False otherwise.
|
||||
Default implementation tries _verify_if_layer_quantized and catches exceptions.
|
||||
Subclasses can override for more efficient checking.
|
||||
"""
|
||||
try:
|
||||
self._verify_if_layer_quantized("", module, {})
|
||||
return True
|
||||
except (AssertionError, AttributeError):
|
||||
return False
|
||||
|
||||
def _load_unquantized_model(self):
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {})
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _test_quantization_num_parameters(self, config_kwargs):
|
||||
model = self._load_unquantized_model()
|
||||
num_params = model.num_parameters()
|
||||
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
num_params_quantized = model_quantized.num_parameters()
|
||||
|
||||
assert (
|
||||
num_params == num_params_quantized
|
||||
), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
|
||||
|
||||
def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
|
||||
model = self._load_unquantized_model()
|
||||
mem = model.get_memory_footprint()
|
||||
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
mem_quantized = model_quantized.get_memory_footprint()
|
||||
|
||||
ratio = mem / mem_quantized
|
||||
assert (
|
||||
ratio >= expected_memory_reduction
|
||||
), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
|
||||
|
||||
def _test_quantization_inference(self, config_kwargs):
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model_quantized(**inputs)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert output is not None, "Model output is None"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN"
|
||||
|
||||
def _test_quantization_dtype_assignment(self, config_kwargs):
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.to(torch.float16)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
device_0 = f"{torch_device}:0"
|
||||
model.to(device=device_0, dtype=torch.float16)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.float()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.half()
|
||||
|
||||
model.to(torch_device)
|
||||
|
||||
def _test_quantization_lora_inference(self, config_kwargs):
|
||||
try:
|
||||
from peft import LoraConfig
|
||||
except ImportError:
|
||||
pytest.skip("peft is not available")
|
||||
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__})")
|
||||
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
)
|
||||
model.add_adapter(lora_config)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model(**inputs)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert output is not None, "Model output is None with LoRA"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"
|
||||
|
||||
def _test_quantization_serialization(self, config_kwargs):
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_pretrained(tmpdir, safe_serialization=True)
|
||||
|
||||
model_loaded = self.model_class.from_pretrained(tmpdir)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model_loaded(**inputs)
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert not torch.isnan(output).any(), "Loaded model output contains NaN"
|
||||
|
||||
def _test_quantized_layers(self, config_kwargs):
|
||||
model_fp = self._load_unquantized_model()
|
||||
num_linear_layers = sum(1 for module in model_fp.modules() if isinstance(module, torch.nn.Linear))
|
||||
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
|
||||
num_fp32_modules = 0
|
||||
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
|
||||
for name, module in model_quantized.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
|
||||
num_fp32_modules += 1
|
||||
|
||||
expected_quantized_layers = num_linear_layers - num_fp32_modules
|
||||
|
||||
num_quantized_layers = 0
|
||||
for name, module in model_quantized.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
|
||||
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
|
||||
continue
|
||||
self._verify_if_layer_quantized(name, module, config_kwargs)
|
||||
num_quantized_layers += 1
|
||||
|
||||
assert (
|
||||
num_quantized_layers > 0
|
||||
), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
|
||||
assert (
|
||||
num_quantized_layers == expected_quantized_layers
|
||||
), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
|
||||
|
||||
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
|
||||
"""
|
||||
Test that modules specified in modules_to_not_convert are not quantized.
|
||||
|
||||
Args:
|
||||
config_kwargs: Base quantization config kwargs
|
||||
modules_to_not_convert: List of module names to exclude from quantization
|
||||
"""
|
||||
# Create config with modules_to_not_convert
|
||||
config_kwargs_with_exclusion = config_kwargs.copy()
|
||||
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
|
||||
|
||||
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
|
||||
|
||||
# Find a module that should NOT be quantized
|
||||
found_excluded = False
|
||||
for name, module in model_with_exclusion.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Check if this module is in the exclusion list
|
||||
if any(excluded in name for excluded in modules_to_not_convert):
|
||||
found_excluded = True
|
||||
# This module should NOT be quantized
|
||||
assert not self._is_module_quantized(
|
||||
module
|
||||
), f"Module {name} should not be quantized but was found to be quantized"
|
||||
|
||||
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
|
||||
|
||||
# Find a module that SHOULD be quantized (not in exclusion list)
|
||||
found_quantized = False
|
||||
for name, module in model_with_exclusion.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Check if this module is NOT in the exclusion list
|
||||
if not any(excluded in name for excluded in modules_to_not_convert):
|
||||
if self._is_module_quantized(module):
|
||||
found_quantized = True
|
||||
break
|
||||
|
||||
assert found_quantized, "No quantized layers found outside of excluded modules"
|
||||
|
||||
# Compare memory footprint with fully quantized model
|
||||
model_fully_quantized = self._create_quantized_model(config_kwargs)
|
||||
|
||||
mem_with_exclusion = model_with_exclusion.get_memory_footprint()
|
||||
mem_fully_quantized = model_fully_quantized.get_memory_footprint()
|
||||
|
||||
assert (
|
||||
mem_with_exclusion > mem_fully_quantized
|
||||
), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
|
||||
|
||||
def _test_quantization_device_map(self, config_kwargs):
|
||||
"""
|
||||
Test that quantized models work correctly with device_map="auto".
|
||||
|
||||
Args:
|
||||
config_kwargs: Base quantization config kwargs
|
||||
"""
|
||||
model = self._create_quantized_model(config_kwargs, device_map="auto")
|
||||
|
||||
# Verify device map is set
|
||||
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
|
||||
assert model.hf_device_map is not None, "hf_device_map should not be None"
|
||||
|
||||
# Verify inference works
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model(**inputs)
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert output is not None, "Model output is None"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN"
|
||||
|
||||
|
||||
@is_bitsandbytes
|
||||
@nightly
|
||||
@require_accelerator
|
||||
@require_bitsandbytes_version_greater("0.43.2")
|
||||
@require_accelerate
|
||||
class BitsAndBytesTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing BitsAndBytes quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test
|
||||
|
||||
Pytest mark: bitsandbytes
|
||||
Use `pytest -m "not bitsandbytes"` to skip these tests
|
||||
"""
|
||||
|
||||
# Standard BnB configs tested for all models
|
||||
# Subclasses can override to add or modify configs
|
||||
BNB_CONFIGS = {
|
||||
"4bit_nf4": {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.float16,
|
||||
},
|
||||
"4bit_fp4": {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_quant_type": "fp4",
|
||||
"bnb_4bit_compute_dtype": torch.float16,
|
||||
},
|
||||
"8bit": {
|
||||
"load_in_8bit": True,
|
||||
},
|
||||
}
|
||||
|
||||
BNB_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"4bit_nf4": 3.0,
|
||||
"4bit_fp4": 3.0,
|
||||
"8bit": 1.5,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = BitsAndBytesConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
|
||||
assert (
|
||||
module.weight.__class__ == expected_weight_class
|
||||
), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_num_parameters(self, config_name):
|
||||
self._test_quantization_num_parameters(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_memory_footprint(self, config_name):
|
||||
expected = self.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
|
||||
self._test_quantization_memory_footprint(self.BNB_CONFIGS[config_name], expected_memory_reduction=expected)
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_inference(self, config_name):
|
||||
self._test_quantization_inference(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
|
||||
def test_bnb_quantization_dtype_assignment(self, config_name):
|
||||
self._test_quantization_dtype_assignment(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
|
||||
def test_bnb_quantization_lora_inference(self, config_name):
|
||||
self._test_quantization_lora_inference(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
|
||||
def test_bnb_quantization_serialization(self, config_name):
|
||||
self._test_quantization_serialization(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantized_layers(self, config_name):
|
||||
self._test_quantized_layers(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_config_serialization(self, config_name):
|
||||
model = self._create_quantized_model(self.BNB_CONFIGS[config_name])
|
||||
|
||||
assert "quantization_config" in model.config, "Missing quantization_config"
|
||||
_ = model.config["quantization_config"].to_dict()
|
||||
_ = model.config["quantization_config"].to_diff_dict()
|
||||
_ = model.config["quantization_config"].to_json_string()
|
||||
|
||||
def test_bnb_original_dtype(self):
|
||||
config_name = list(self.BNB_CONFIGS.keys())[0]
|
||||
config_kwargs = self.BNB_CONFIGS[config_name]
|
||||
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
assert "_pre_quantization_dtype" in model.config, "Missing _pre_quantization_dtype"
|
||||
assert model.config["_pre_quantization_dtype"] in [
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}"
|
||||
|
||||
def test_bnb_keep_modules_in_fp32(self):
|
||||
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
|
||||
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
|
||||
|
||||
config_kwargs = self.BNB_CONFIGS["4bit_nf4"]
|
||||
|
||||
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
|
||||
self.model_class._keep_in_fp32_modules = ["proj_out"]
|
||||
|
||||
try:
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
|
||||
assert (
|
||||
module.weight.dtype == torch.float32
|
||||
), f"Module {name} should be FP32 but is {module.weight.dtype}"
|
||||
else:
|
||||
assert (
|
||||
module.weight.dtype == torch.uint8
|
||||
), f"Module {name} should be uint8 but is {module.weight.dtype}"
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
_ = model(**inputs)
|
||||
finally:
|
||||
if original_fp32_modules is not None:
|
||||
self.model_class._keep_in_fp32_modules = original_fp32_modules
|
||||
|
||||
def test_bnb_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(self.BNB_CONFIGS["4bit_nf4"], modules_to_exclude)
|
||||
|
||||
def test_bnb_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"])
|
||||
|
||||
|
||||
@is_quanto
|
||||
@nightly
|
||||
@require_quanto
|
||||
@require_accelerate
|
||||
@require_accelerator
|
||||
class QuantoTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing Quanto quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype
|
||||
|
||||
Pytest mark: quanto
|
||||
Use `pytest -m "not quanto"` to skip these tests
|
||||
"""
|
||||
|
||||
QUANTO_WEIGHT_TYPES = {
|
||||
"float8": {"weights_dtype": "float8"},
|
||||
"int8": {"weights_dtype": "int8"},
|
||||
"int4": {"weights_dtype": "int4"},
|
||||
"int2": {"weights_dtype": "int2"},
|
||||
}
|
||||
|
||||
QUANTO_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"float8": 1.5,
|
||||
"int8": 1.5,
|
||||
"int4": 3.0,
|
||||
"int2": 7.0,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = QuantoConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}"
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
|
||||
def test_quanto_quantization_num_parameters(self, weight_type_name):
|
||||
self._test_quantization_num_parameters(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
|
||||
def test_quanto_quantization_memory_footprint(self, weight_type_name):
|
||||
expected = self.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2)
|
||||
self._test_quantization_memory_footprint(
|
||||
self.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
|
||||
def test_quanto_quantization_inference(self, weight_type_name):
|
||||
self._test_quantization_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", ["int8"])
|
||||
def test_quanto_quantized_layers(self, weight_type_name):
|
||||
self._test_quantized_layers(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", ["int8"])
|
||||
def test_quanto_quantization_lora_inference(self, weight_type_name):
|
||||
self._test_quantization_lora_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", ["int8"])
|
||||
def test_quanto_quantization_serialization(self, weight_type_name):
|
||||
self._test_quantization_serialization(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
def test_quanto_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(self.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude)
|
||||
|
||||
def test_quanto_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"])
|
||||
|
||||
|
||||
@is_torchao
|
||||
@require_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
class TorchAoTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing TorchAO quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- TORCHAO_QUANT_TYPES: Dict of quantization type strings to test
|
||||
|
||||
Pytest mark: torchao
|
||||
Use `pytest -m "not torchao"` to skip these tests
|
||||
"""
|
||||
|
||||
TORCHAO_QUANT_TYPES = {
|
||||
"int4wo": {"quant_type": "int4_weight_only"},
|
||||
"int8wo": {"quant_type": "int8_weight_only"},
|
||||
"int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
|
||||
}
|
||||
|
||||
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"int4wo": 3.0,
|
||||
"int8wo": 1.5,
|
||||
"int8dq": 1.5,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = TorchAoConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
|
||||
|
||||
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
|
||||
def test_torchao_quantization_num_parameters(self, quant_type):
|
||||
self._test_quantization_num_parameters(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
|
||||
def test_torchao_quantization_memory_footprint(self, quant_type):
|
||||
expected = self.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2)
|
||||
self._test_quantization_memory_footprint(
|
||||
self.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
|
||||
def test_torchao_quantization_inference(self, quant_type):
|
||||
self._test_quantization_inference(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", ["int8wo"])
|
||||
def test_torchao_quantized_layers(self, quant_type):
|
||||
self._test_quantized_layers(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", ["int8wo"])
|
||||
def test_torchao_quantization_lora_inference(self, quant_type):
|
||||
self._test_quantization_lora_inference(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", ["int8wo"])
|
||||
def test_torchao_quantization_serialization(self, quant_type):
|
||||
self._test_quantization_serialization(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
def test_torchao_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
# Get a module name that exists in the model - this needs to be set by test classes
|
||||
# For now, use a generic pattern that should work with transformer models
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(
|
||||
self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
|
||||
)
|
||||
|
||||
def test_torchao_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"])
|
||||
|
||||
|
||||
@is_gguf
|
||||
@nightly
|
||||
@require_accelerate
|
||||
@require_accelerator
|
||||
@require_gguf_version_greater_or_equal("0.10.0")
|
||||
class GGUFTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing GGUF quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- gguf_filename: URL or path to the GGUF file
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: gguf
|
||||
Use `pytest -m "not gguf"` to skip these tests
|
||||
"""
|
||||
|
||||
gguf_filename = None
|
||||
|
||||
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
|
||||
if config_kwargs is None:
|
||||
config_kwargs = {"compute_dtype": torch.bfloat16}
|
||||
|
||||
config = GGUFQuantizationConfig(**config_kwargs)
|
||||
kwargs = {
|
||||
"quantization_config": config,
|
||||
"torch_dtype": config_kwargs.get("compute_dtype", torch.bfloat16),
|
||||
}
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_single_file(self.gguf_filename, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs=None):
|
||||
from diffusers.quantizers.gguf.utils import GGUFParameter
|
||||
|
||||
assert isinstance(module.weight, GGUFParameter), f"{name} weight is not GGUFParameter"
|
||||
assert hasattr(module.weight, "quant_type"), f"{name} weight missing quant_type"
|
||||
assert module.weight.dtype == torch.uint8, f"{name} weight dtype should be uint8"
|
||||
|
||||
def test_gguf_quantization_inference(self):
|
||||
self._test_quantization_inference({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_keep_modules_in_fp32(self):
|
||||
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
|
||||
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
|
||||
|
||||
_keep_in_fp32_modules = self.model_class._keep_in_fp32_modules
|
||||
self.model_class._keep_in_fp32_modules = ["proj_out"]
|
||||
|
||||
try:
|
||||
model = self._create_quantized_model()
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
|
||||
assert module.weight.dtype == torch.float32, f"Module {name} should be FP32"
|
||||
finally:
|
||||
self.model_class._keep_in_fp32_modules = _keep_in_fp32_modules
|
||||
|
||||
def test_gguf_quantization_dtype_assignment(self):
|
||||
self._test_quantization_dtype_assignment({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_quantization_lora_inference(self):
|
||||
self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_dequantize_model(self):
|
||||
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
|
||||
|
||||
model = self._create_quantized_model()
|
||||
model.dequantize()
|
||||
|
||||
def _check_for_gguf_linear(model):
|
||||
has_children = list(model.children())
|
||||
if not has_children:
|
||||
return
|
||||
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
|
||||
assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
|
||||
|
||||
for name, module in model.named_children():
|
||||
_check_for_gguf_linear(module)
|
||||
|
||||
def test_gguf_quantized_layers(self):
|
||||
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
|
||||
|
||||
|
||||
@is_modelopt
|
||||
@nightly
|
||||
@require_accelerator
|
||||
@require_accelerate
|
||||
@require_modelopt_version_greater_or_equal("0.33.1")
|
||||
class ModelOptTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing NVIDIA ModelOpt quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test
|
||||
|
||||
Pytest mark: modelopt
|
||||
Use `pytest -m "not modelopt"` to skip these tests
|
||||
"""
|
||||
|
||||
MODELOPT_CONFIGS = {
|
||||
"fp8": {"quant_type": "FP8"},
|
||||
"int8": {"quant_type": "INT8"},
|
||||
"int4": {"quant_type": "INT4"},
|
||||
}
|
||||
|
||||
MODELOPT_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"fp8": 1.5,
|
||||
"int8": 1.5,
|
||||
"int4": 3.0,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = NVIDIAModelOptConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)"
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantization_num_parameters(self, config_name):
|
||||
self._test_quantization_num_parameters(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
|
||||
def test_modelopt_quantization_memory_footprint(self, config_name):
|
||||
expected = self.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
|
||||
self._test_quantization_memory_footprint(
|
||||
self.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
|
||||
def test_modelopt_quantization_inference(self, config_name):
|
||||
self._test_quantization_inference(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantization_dtype_assignment(self, config_name):
|
||||
self._test_quantization_dtype_assignment(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantization_lora_inference(self, config_name):
|
||||
self._test_quantization_lora_inference(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantization_serialization(self, config_name):
|
||||
self._test_quantization_serialization(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["fp8"])
|
||||
def test_modelopt_quantized_layers(self, config_name):
|
||||
self._test_quantized_layers(self.MODELOPT_CONFIGS[config_name])
|
||||
|
||||
def test_modelopt_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(self.MODELOPT_CONFIGS["fp8"], modules_to_exclude)
|
||||
|
||||
def test_modelopt_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"])
|
||||
@@ -1,247 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_single_file,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
|
||||
"""Download a single file checkpoint from the Hub to a temporary directory."""
|
||||
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
|
||||
return path
|
||||
|
||||
|
||||
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
|
||||
"""Download diffusers config files (excluding weights) from a repository."""
|
||||
path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
ignore_patterns=[
|
||||
"**/*.ckpt",
|
||||
"*.ckpt",
|
||||
"**/*.bin",
|
||||
"*.bin",
|
||||
"**/*.pt",
|
||||
"*.pt",
|
||||
"**/*.safetensors",
|
||||
"*.safetensors",
|
||||
],
|
||||
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
|
||||
local_dir=tmpdir,
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@is_single_file
|
||||
class SingleFileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing single file loading for models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- ckpt_path: Path or Hub path to the single file checkpoint
|
||||
- subfolder: (Optional) Subfolder within the repo
|
||||
- torch_dtype: (Optional) torch dtype to use for testing
|
||||
|
||||
Pytest mark: single_file
|
||||
Use `pytest -m "not single_file"` to skip these tests
|
||||
"""
|
||||
|
||||
pretrained_model_name_or_path = None
|
||||
ckpt_path = None
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_model_config(self):
|
||||
pretrained_kwargs = {}
|
||||
single_file_kwargs = {}
|
||||
|
||||
pretrained_kwargs["device"] = torch_device
|
||||
single_file_kwargs["device"] = torch_device
|
||||
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading: "
|
||||
f"pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
)
|
||||
|
||||
def test_single_file_model_parameters(self):
|
||||
pretrained_kwargs = {}
|
||||
single_file_kwargs = {}
|
||||
|
||||
pretrained_kwargs["device"] = torch_device
|
||||
single_file_kwargs["device"] = torch_device
|
||||
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict_single_file = model_single_file.state_dict()
|
||||
|
||||
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
|
||||
"Model parameters keys differ between pretrained and single file loading. "
|
||||
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
|
||||
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
|
||||
)
|
||||
|
||||
for key in state_dict.keys():
|
||||
param = state_dict[key]
|
||||
param_single_file = state_dict_single_file[key]
|
||||
|
||||
assert param.shape == param_single_file.shape, (
|
||||
f"Parameter shape mismatch for {key}: "
|
||||
f"pretrained {param.shape} vs single file {param_single_file.shape}"
|
||||
)
|
||||
|
||||
assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
|
||||
f"Parameter values differ for {key}: "
|
||||
f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
|
||||
)
|
||||
|
||||
def test_single_file_loading_local_files_only(self):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with local_files_only=True"
|
||||
|
||||
def test_single_file_loading_with_diffusers_config(self):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
# Load with config parameter
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
|
||||
)
|
||||
|
||||
# Load pretrained for comparison
|
||||
pretrained_kwargs = {}
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
|
||||
# Compare configs
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
|
||||
def test_single_file_loading_with_diffusers_config_local_files_only(self):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
|
||||
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir)
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
|
||||
|
||||
def test_single_file_loading_dtype(self):
|
||||
for dtype in [torch.float32, torch.float16]:
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
continue
|
||||
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
|
||||
|
||||
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
|
||||
|
||||
# Cleanup
|
||||
del model_single_file
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_checkpoint_variant_loading(self):
|
||||
if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
|
||||
return
|
||||
|
||||
for ckpt_path in self.alternate_ckpt_paths:
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
single_file_kwargs = {}
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
|
||||
|
||||
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
@@ -1,224 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.training_utils import EMAModel
|
||||
|
||||
from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device
|
||||
|
||||
|
||||
@is_training
|
||||
@require_torch_accelerator_with_training
|
||||
class TrainingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing training functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Expected properties to be implemented by subclasses:
|
||||
- output_shape: Tuple defining the expected output shape
|
||||
|
||||
Pytest mark: training
|
||||
Use `pytest -m "not training"` to skip these tests
|
||||
"""
|
||||
|
||||
def test_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
|
||||
def test_training_with_ema(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
ema_model = EMAModel(model.parameters())
|
||||
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
ema_model.step(model.parameters())
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
# at init model should have gradient checkpointing disabled
|
||||
model = self.model_class(**init_dict)
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
|
||||
|
||||
# check enable works
|
||||
model.enable_gradient_checkpointing()
|
||||
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
|
||||
|
||||
# check disable works
|
||||
model.disable_gradient_checkpointing()
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self, expected_set=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if expected_set is None:
|
||||
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
model_class_copy = copy.copy(self.model_class)
|
||||
model = model_class_copy(**init_dict)
|
||||
model.enable_gradient_checkpointing()
|
||||
|
||||
modules_with_gc_enabled = {}
|
||||
for submodule in model.modules():
|
||||
if hasattr(submodule, "gradient_checkpointing"):
|
||||
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
|
||||
modules_with_gc_enabled[submodule.__class__.__name__] = True
|
||||
|
||||
assert set(modules_with_gc_enabled.keys()) == expected_set, (
|
||||
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} "
|
||||
f"do not match expected set {expected_set}"
|
||||
)
|
||||
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
|
||||
|
||||
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if skip is None:
|
||||
skip = set()
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict_copy = copy.deepcopy(inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict)
|
||||
if isinstance(out, dict):
|
||||
out = out.sample if hasattr(out, "sample") else out.to_tuple()[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
torch.manual_seed(0)
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict_copy)
|
||||
if isinstance(out_2, dict):
|
||||
out_2 = out_2.sample if hasattr(out_2, "sample") else out_2.to_tuple()[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
assert (
|
||||
loss - loss_2
|
||||
).abs() < loss_tolerance, f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
|
||||
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
|
||||
for name, param in named_params.items():
|
||||
if "post_quant_conv" in name:
|
||||
continue
|
||||
if name in skip:
|
||||
continue
|
||||
if param.grad is None:
|
||||
continue
|
||||
|
||||
assert torch_all_close(
|
||||
param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol
|
||||
), f"Gradient mismatch for {name}"
|
||||
|
||||
def test_mixed_precision_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# Test with float16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# Test with bfloat16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
model.zero_grad()
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
@@ -1,316 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
IPAdapterTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
ModelTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
SingleFileTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class FluxTransformerTesterConfig:
|
||||
model_class = FluxTransformer2DModel
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
|
||||
pretrained_model_kwargs = {"subfolder": "transformer"}
|
||||
|
||||
def get_init_dict(self):
|
||||
"""Return Flux model initialization arguments."""
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"pooled_projection_dim": 32,
|
||||
"axes_dims_rope": [4, 4, 8],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 1
|
||||
height = width = 4
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
|
||||
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_deprecated_inputs_img_txt_ids_3d(self):
|
||||
"""Test that deprecated 3D img_ids and txt_ids still work."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output_1 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
|
||||
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
|
||||
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
|
||||
|
||||
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
|
||||
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
|
||||
|
||||
inputs_dict["txt_ids"] = text_ids_3d
|
||||
inputs_dict["img_ids"] = image_ids_3d
|
||||
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
assert output_1.shape == output_2.shape
|
||||
assert torch.allclose(output_1, output_2, atol=1e-5), (
|
||||
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
|
||||
"are not equal as them as 2d inputs"
|
||||
)
|
||||
|
||||
|
||||
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for Flux Transformer."""
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
|
||||
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 0
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
continue
|
||||
|
||||
joint_attention_dim = model.config["joint_attention_dim"]
|
||||
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
|
||||
sd = FluxIPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
|
||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 1
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=(
|
||||
model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
|
||||
),
|
||||
num_image_text_embeds=4,
|
||||
)
|
||||
|
||||
ip_image_projection_state_dict = {}
|
||||
sd = image_projection.state_dict()
|
||||
ip_image_projection_state_dict.update(
|
||||
{
|
||||
"proj.weight": sd["image_embeds.weight"],
|
||||
"proj.bias": sd["image_embeds.bias"],
|
||||
"norm.weight": sd["norm.weight"],
|
||||
"norm.bias": sd["norm.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
del sd
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for Flux Transformer."""
|
||||
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height=4, width=4):
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height=4, width=4):
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
|
||||
subfolder = "transformer"
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
|
||||
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
|
||||
def get_dummy_inputs(self):
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
@@ -98,9 +98,9 @@ class GGUFCudaKernelsTests(unittest.TestCase):
|
||||
output_native = linear.forward_native(x)
|
||||
output_cuda = linear.forward_cuda(x)
|
||||
|
||||
assert torch.allclose(
|
||||
output_native, output_cuda, 1e-2
|
||||
), f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
|
||||
assert torch.allclose(output_native, output_cuda, 1e-2), (
|
||||
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
|
||||
)
|
||||
|
||||
|
||||
@nightly
|
||||
|
||||
+81
-197
@@ -13,6 +13,7 @@ import struct
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
import urllib.parse
|
||||
from collections import UserDict
|
||||
from contextlib import contextmanager
|
||||
@@ -23,7 +24,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tupl
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
import pytest
|
||||
import requests
|
||||
from numpy.linalg import norm
|
||||
from packaging import version
|
||||
@@ -241,6 +241,7 @@ def parse_flag_from_env(key, default=False):
|
||||
|
||||
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
||||
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
|
||||
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
|
||||
|
||||
|
||||
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
@@ -266,7 +267,7 @@ def slow(test_case):
|
||||
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
|
||||
|
||||
"""
|
||||
return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)
|
||||
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||
|
||||
|
||||
def nightly(test_case):
|
||||
@@ -276,149 +277,33 @@ def nightly(test_case):
|
||||
Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
|
||||
|
||||
"""
|
||||
return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case)
|
||||
return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
|
||||
|
||||
|
||||
def is_torch_compile(test_case):
|
||||
"""
|
||||
Decorator marking a test as a torch.compile test. These tests can be filtered using:
|
||||
pytest -m "not compile" to skip
|
||||
pytest -m compile to run only these tests
|
||||
"""
|
||||
return pytest.mark.compile(test_case)
|
||||
Decorator marking a test that runs compile tests in the diffusers CI.
|
||||
|
||||
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
|
||||
|
||||
def is_single_file(test_case):
|
||||
"""
|
||||
Decorator marking a test as a single file loading test. These tests can be filtered using:
|
||||
pytest -m "not single_file" to skip
|
||||
pytest -m single_file to run only these tests
|
||||
"""
|
||||
return pytest.mark.single_file(test_case)
|
||||
|
||||
|
||||
def is_lora(test_case):
|
||||
"""
|
||||
Decorator marking a test as a LoRA test. These tests can be filtered using:
|
||||
pytest -m "not lora" to skip
|
||||
pytest -m lora to run only these tests
|
||||
"""
|
||||
return pytest.mark.lora(test_case)
|
||||
|
||||
|
||||
def is_ip_adapter(test_case):
|
||||
"""
|
||||
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
|
||||
pytest -m "not ip_adapter" to skip
|
||||
pytest -m ip_adapter to run only these tests
|
||||
"""
|
||||
return pytest.mark.ip_adapter(test_case)
|
||||
|
||||
|
||||
def is_training(test_case):
|
||||
"""
|
||||
Decorator marking a test as a training test. These tests can be filtered using:
|
||||
pytest -m "not training" to skip
|
||||
pytest -m training to run only these tests
|
||||
"""
|
||||
return pytest.mark.training(test_case)
|
||||
|
||||
|
||||
def is_attention(test_case):
|
||||
"""
|
||||
Decorator marking a test as an attention test. These tests can be filtered using:
|
||||
pytest -m "not attention" to skip
|
||||
pytest -m attention to run only these tests
|
||||
"""
|
||||
return pytest.mark.attention(test_case)
|
||||
|
||||
|
||||
def is_memory(test_case):
|
||||
"""
|
||||
Decorator marking a test as a memory optimization test. These tests can be filtered using:
|
||||
pytest -m "not memory" to skip
|
||||
pytest -m memory to run only these tests
|
||||
"""
|
||||
return pytest.mark.memory(test_case)
|
||||
|
||||
|
||||
def is_cpu_offload(test_case):
|
||||
"""
|
||||
Decorator marking a test as a CPU offload test. These tests can be filtered using:
|
||||
pytest -m "not cpu_offload" to skip
|
||||
pytest -m cpu_offload to run only these tests
|
||||
"""
|
||||
return pytest.mark.cpu_offload(test_case)
|
||||
|
||||
|
||||
def is_group_offload(test_case):
|
||||
"""
|
||||
Decorator marking a test as a group offload test. These tests can be filtered using:
|
||||
pytest -m "not group_offload" to skip
|
||||
pytest -m group_offload to run only these tests
|
||||
"""
|
||||
return pytest.mark.group_offload(test_case)
|
||||
|
||||
|
||||
def is_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
|
||||
pytest -m "not bitsandbytes" to skip
|
||||
pytest -m bitsandbytes to run only these tests
|
||||
"""
|
||||
return pytest.mark.bitsandbytes(test_case)
|
||||
|
||||
|
||||
def is_quanto(test_case):
|
||||
"""
|
||||
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
|
||||
pytest -m "not quanto" to skip
|
||||
pytest -m quanto to run only these tests
|
||||
"""
|
||||
return pytest.mark.quanto(test_case)
|
||||
|
||||
|
||||
def is_torchao(test_case):
|
||||
"""
|
||||
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
|
||||
pytest -m "not torchao" to skip
|
||||
pytest -m torchao to run only these tests
|
||||
"""
|
||||
return pytest.mark.torchao(test_case)
|
||||
|
||||
|
||||
def is_gguf(test_case):
|
||||
"""
|
||||
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
|
||||
pytest -m "not gguf" to skip
|
||||
pytest -m gguf to run only these tests
|
||||
"""
|
||||
return pytest.mark.gguf(test_case)
|
||||
|
||||
|
||||
def is_modelopt(test_case):
|
||||
"""
|
||||
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
|
||||
pytest -m "not modelopt" to skip
|
||||
pytest -m modelopt to run only these tests
|
||||
"""
|
||||
return pytest.mark.modelopt(test_case)
|
||||
return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_torch_available(), reason="test requires PyTorch")(test_case)
|
||||
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_torch_2(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(
|
||||
not (is_torch_available() and is_torch_version(">=", "2.0.0")), reason="test requires PyTorch 2"
|
||||
)(test_case)
|
||||
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_version_greater_equal(torch_version):
|
||||
@@ -426,9 +311,8 @@ def require_torch_version_greater_equal(torch_version):
|
||||
|
||||
def decorator(test_case):
|
||||
correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_torch_version,
|
||||
reason=f"test requires torch with the version greater than or equal to {torch_version}",
|
||||
return unittest.skipUnless(
|
||||
correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -439,8 +323,8 @@ def require_torch_version_greater(torch_version):
|
||||
|
||||
def decorator(test_case):
|
||||
correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_torch_version, reason=f"test requires torch with the version greater than {torch_version}"
|
||||
return unittest.skipUnless(
|
||||
correct_torch_version, f"test requires torch with the version greater than {torch_version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -448,18 +332,19 @@ def require_torch_version_greater(torch_version):
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return pytest.mark.skipif(torch_device != "cuda", reason="test requires PyTorch+CUDA")(test_case)
|
||||
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_cuda_compatibility(expected_compute_capability):
|
||||
def decorator(test_case):
|
||||
if torch.cuda.is_available():
|
||||
current_compute_capability = get_torch_cuda_device_capability()
|
||||
return pytest.mark.skipif(
|
||||
float(current_compute_capability) != float(expected_compute_capability),
|
||||
reason="Test not supported for this compute capability.",
|
||||
)(test_case)
|
||||
return test_case
|
||||
return unittest.skipUnless(
|
||||
float(current_compute_capability) == float(expected_compute_capability),
|
||||
"Test not supported for this compute capability.",
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -467,7 +352,9 @@ def require_torch_cuda_compatibility(expected_compute_capability):
|
||||
# These decorators are for accelerator-specific behaviours that are not GPU-specific
|
||||
def require_torch_accelerator(test_case):
|
||||
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
|
||||
return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case)
|
||||
return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_multi_gpu(test_case):
|
||||
@@ -477,11 +364,11 @@ def require_torch_multi_gpu(test_case):
|
||||
-k "multi_gpu"
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
return pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="test requires multiple GPUs")(test_case)
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_accelerator(test_case):
|
||||
@@ -490,28 +377,27 @@ def require_torch_multi_accelerator(test_case):
|
||||
without multiple hardware accelerators.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
return pytest.mark.skipif(
|
||||
not (torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1),
|
||||
reason="test requires multiple hardware accelerators",
|
||||
return unittest.skipUnless(
|
||||
torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp16(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
|
||||
return pytest.mark.skipif(
|
||||
not _is_torch_fp16_available(torch_device), reason="test requires accelerator with fp16 support"
|
||||
)(test_case)
|
||||
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp64(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP64 data type."""
|
||||
return pytest.mark.skipif(
|
||||
not _is_torch_fp64_available(torch_device), reason="test requires accelerator with fp64 support"
|
||||
)(test_case)
|
||||
return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_big_gpu_with_torch_cuda(test_case):
|
||||
@@ -520,17 +406,17 @@ def require_big_gpu_with_torch_cuda(test_case):
|
||||
etc.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
|
||||
return unittest.skip("test requires PyTorch CUDA")(test_case)
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(0)
|
||||
total_memory = device_properties.total_memory / (1024**3)
|
||||
return pytest.mark.skipif(
|
||||
total_memory < BIG_GPU_MEMORY, reason=f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
|
||||
return unittest.skipUnless(
|
||||
total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
|
||||
)(test_case)
|
||||
|
||||
|
||||
@@ -544,12 +430,12 @@ def require_big_accelerator(test_case):
|
||||
test_case = pytest.mark.big_accelerator(test_case)
|
||||
|
||||
if not is_torch_available():
|
||||
return pytest.mark.skip(reason="test requires PyTorch")(test_case)
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
||||
return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
|
||||
return unittest.skip("test requires PyTorch CUDA")(test_case)
|
||||
|
||||
if torch.xpu.is_available():
|
||||
device_properties = torch.xpu.get_device_properties(0)
|
||||
@@ -557,30 +443,30 @@ def require_big_accelerator(test_case):
|
||||
device_properties = torch.cuda.get_device_properties(0)
|
||||
|
||||
total_memory = device_properties.total_memory / (1024**3)
|
||||
return pytest.mark.skipif(
|
||||
total_memory < BIG_GPU_MEMORY,
|
||||
reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
|
||||
return unittest.skipUnless(
|
||||
total_memory >= BIG_GPU_MEMORY,
|
||||
f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_training(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for training."""
|
||||
return pytest.mark.skipif(
|
||||
not (is_torch_available() and backend_supports_training(torch_device)),
|
||||
reason="test requires accelerator with training support",
|
||||
return unittest.skipUnless(
|
||||
is_torch_available() and backend_supports_training(torch_device),
|
||||
"test requires accelerator with training support",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def skip_mps(test_case):
|
||||
"""Decorator marking a test to skip if torch_device is 'mps'"""
|
||||
return pytest.mark.skipif(torch_device == "mps", reason="test requires non 'mps' device")(test_case)
|
||||
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
|
||||
|
||||
|
||||
def require_flax(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
|
||||
"""
|
||||
return pytest.mark.skipif(not is_flax_available(), reason="test requires JAX & Flax")(test_case)
|
||||
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
|
||||
|
||||
|
||||
def require_compel(test_case):
|
||||
@@ -588,21 +474,21 @@ def require_compel(test_case):
|
||||
Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
|
||||
the library is not installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_compel_available(), reason="test requires compel")(test_case)
|
||||
return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
|
||||
|
||||
|
||||
def require_onnxruntime(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_onnx_available(), reason="test requires onnxruntime")(test_case)
|
||||
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
|
||||
|
||||
|
||||
def require_note_seq(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_note_seq_available(), reason="test requires note_seq")(test_case)
|
||||
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
|
||||
|
||||
|
||||
def require_accelerator(test_case):
|
||||
@@ -610,14 +496,14 @@ def require_accelerator(test_case):
|
||||
Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
|
||||
hardware accelerator available.
|
||||
"""
|
||||
return pytest.mark.skipif(torch_device == "cpu", reason="test requires a hardware accelerator")(test_case)
|
||||
return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
|
||||
|
||||
|
||||
def require_torchsde(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_torchsde_available(), reason="test requires torchsde")(test_case)
|
||||
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
|
||||
|
||||
|
||||
def require_peft_backend(test_case):
|
||||
@@ -625,35 +511,35 @@ def require_peft_backend(test_case):
|
||||
Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
|
||||
transformers.
|
||||
"""
|
||||
return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case)
|
||||
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
|
||||
|
||||
|
||||
def require_timm(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_timm_available(), reason="test requires timm")(test_case)
|
||||
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
|
||||
|
||||
|
||||
def require_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")(test_case)
|
||||
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
|
||||
|
||||
|
||||
def require_quanto(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_optimum_quanto_available(), reason="test requires quanto")(test_case)
|
||||
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
|
||||
|
||||
|
||||
def require_accelerate(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
|
||||
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
|
||||
|
||||
|
||||
def require_peft_version_greater(peft_version):
|
||||
@@ -666,8 +552,8 @@ def require_peft_version_greater(peft_version):
|
||||
correct_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
) > version.parse(peft_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}"
|
||||
return unittest.skipUnless(
|
||||
correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -683,9 +569,9 @@ def require_transformers_version_greater(transformers_version):
|
||||
correct_transformers_version = is_transformers_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("transformers")).base_version
|
||||
) > version.parse(transformers_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_transformers_version,
|
||||
reason=f"test requires transformers with the version greater than {transformers_version}",
|
||||
return unittest.skipUnless(
|
||||
correct_transformers_version,
|
||||
f"test requires transformers with the version greater than {transformers_version}",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -696,9 +582,8 @@ def require_accelerate_version_greater(accelerate_version):
|
||||
correct_accelerate_version = is_accelerate_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("accelerate")).base_version
|
||||
) > version.parse(accelerate_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_accelerate_version,
|
||||
reason=f"Test requires accelerate with the version greater than {accelerate_version}.",
|
||||
return unittest.skipUnless(
|
||||
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -709,8 +594,8 @@ def require_bitsandbytes_version_greater(bnb_version):
|
||||
correct_bnb_version = is_bitsandbytes_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("bitsandbytes")).base_version
|
||||
) > version.parse(bnb_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_bnb_version, reason=f"Test requires bitsandbytes with the version greater than {bnb_version}."
|
||||
return unittest.skipUnless(
|
||||
correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -721,9 +606,8 @@ def require_hf_hub_version_greater(hf_hub_version):
|
||||
correct_hf_hub_version = version.parse(
|
||||
version.parse(importlib.metadata.version("huggingface_hub")).base_version
|
||||
) > version.parse(hf_hub_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_hf_hub_version,
|
||||
reason=f"Test requires huggingface_hub with the version greater than {hf_hub_version}.",
|
||||
return unittest.skipUnless(
|
||||
correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -734,8 +618,8 @@ def require_gguf_version_greater_or_equal(gguf_version):
|
||||
correct_gguf_version = is_gguf_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("gguf")).base_version
|
||||
) >= version.parse(gguf_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_gguf_version, reason=f"Test requires gguf with the version greater than {gguf_version}."
|
||||
return unittest.skipUnless(
|
||||
correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -746,8 +630,8 @@ def require_torchao_version_greater_or_equal(torchao_version):
|
||||
correct_torchao_version = is_torchao_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("torchao")).base_version
|
||||
) >= version.parse(torchao_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_torchao_version, reason=f"Test requires torchao with version greater than {torchao_version}."
|
||||
return unittest.skipUnless(
|
||||
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -758,8 +642,8 @@ def require_kernels_version_greater_or_equal(kernels_version):
|
||||
correct_kernels_version = is_kernels_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("kernels")).base_version
|
||||
) >= version.parse(kernels_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_kernels_version, reason=f"Test requires kernels with version greater than {kernels_version}."
|
||||
return unittest.skipUnless(
|
||||
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
@@ -769,7 +653,7 @@ def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
"""
|
||||
return pytest.mark.skipif(USE_PEFT_BACKEND, reason="test skipped in favor of PEFT backend")(test_case)
|
||||
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
|
||||
|
||||
|
||||
def get_python_version():
|
||||
@@ -1180,8 +1064,8 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
|
||||
To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
|
||||
|
||||
Args:
|
||||
test_case:
|
||||
The test case object that will run `target_func`.
|
||||
test_case (`unittest.TestCase`):
|
||||
The test that will run `target_func`.
|
||||
target_func (`Callable`):
|
||||
The function implementing the actual testing logic.
|
||||
inputs (`dict`, *optional*, defaults to `None`):
|
||||
@@ -1199,7 +1083,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
|
||||
input_queue = ctx.Queue(1)
|
||||
output_queue = ctx.JoinableQueue(1)
|
||||
|
||||
# We can't send test case objects to the child, otherwise we get issues regarding pickle.
|
||||
# We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
|
||||
input_queue.put(inputs, timeout=timeout)
|
||||
|
||||
process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
|
||||
|
||||
Reference in New Issue
Block a user