Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e68c936f42 | |||
| 051c8a1c0f | |||
| d54622c267 | |||
| df8dd77817 | |||
| 9f3c0fdcd8 | |||
| dccc206e35 | |||
| 6f2ded53a1 | |||
| 6d2a80c14b | |||
| 219a8ab031 | |||
| 3a00e23f5a | |||
| 19fe63170c | |||
| 41381b1bb1 | |||
| bcada5bfaf | |||
| 4490e4cc44 | |||
| 27c1ac49b4 | |||
| 585c32b304 | |||
| ca5afaebca | |||
| 6c066f0e13 | |||
| fbb25a05be | |||
| fbc4c998ed | |||
| 56d2986d5d | |||
| a33ef355f6 | |||
| 85b7478fe9 | |||
| d1e6ffffad | |||
| 61c6eae207 | |||
| a076cd8e16 | |||
| 2b72beefe7 | |||
| 11bf2cf1d1 | |||
| 19921e9362 | |||
| 5aa4f1dc55 | |||
| 922e273e6b |
@@ -22,7 +22,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -e .
|
||||
|
||||
@@ -35,7 +35,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
|
||||
@@ -22,7 +22,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -e .
|
||||
|
||||
@@ -47,7 +47,7 @@ jobs:
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
|
||||
@@ -122,7 +122,7 @@ _deps = [
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"python>=3.8.0",
|
||||
"python>=3.9.0",
|
||||
"ruff==0.9.10",
|
||||
"safetensors>=0.3.1",
|
||||
"sentencepiece>=0.1.91,!=0.1.92",
|
||||
@@ -287,7 +287,7 @@ setup(
|
||||
packages=find_packages("src"),
|
||||
package_data={"diffusers": ["py.typed"]},
|
||||
include_package_data=True,
|
||||
python_requires=">=3.8.0",
|
||||
python_requires=">=3.10.0",
|
||||
install_requires=list(install_requires),
|
||||
extras_require=extras,
|
||||
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
|
||||
|
||||
+12
-12
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from .configuration_utils import ConfigMixin, register_to_config
|
||||
from .utils import CONFIG_NAME
|
||||
@@ -33,13 +33,13 @@ class PipelineCallback(ConfigMixin):
|
||||
raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
|
||||
|
||||
@property
|
||||
def tensor_inputs(self) -> List[str]:
|
||||
def tensor_inputs(self) -> list[str]:
|
||||
raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> dict[str, Any]:
|
||||
raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
|
||||
|
||||
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
|
||||
|
||||
|
||||
@@ -49,14 +49,14 @@ class MultiPipelineCallbacks:
|
||||
provides a unified interface for calling all of them.
|
||||
"""
|
||||
|
||||
def __init__(self, callbacks: List[PipelineCallback]):
|
||||
def __init__(self, callbacks: list[PipelineCallback]):
|
||||
self.callbacks = callbacks
|
||||
|
||||
@property
|
||||
def tensor_inputs(self) -> List[str]:
|
||||
def tensor_inputs(self) -> list[str]:
|
||||
return [input for callback in self.callbacks for input in callback.tensor_inputs]
|
||||
|
||||
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
|
||||
"""
|
||||
@@ -76,7 +76,7 @@ class SDCFGCutoffCallback(PipelineCallback):
|
||||
|
||||
tensor_inputs = ["prompt_embeds"]
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
@@ -109,7 +109,7 @@ class SDXLCFGCutoffCallback(PipelineCallback):
|
||||
"add_time_ids",
|
||||
]
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
@@ -152,7 +152,7 @@ class SDXLControlnetCFGCutoffCallback(PipelineCallback):
|
||||
"image",
|
||||
]
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
@@ -195,7 +195,7 @@ class IPAdapterScaleCutoffCallback(PipelineCallback):
|
||||
|
||||
tensor_inputs = []
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
@@ -219,7 +219,7 @@ class SD3CFGCutoffCallback(PipelineCallback):
|
||||
|
||||
tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
|
||||
@@ -94,10 +94,10 @@ class ConfigMixin:
|
||||
Class attributes:
|
||||
- **config_name** (`str`) -- A filename under which the config should be stored when calling
|
||||
[`~ConfigMixin.save_config`] (should be overridden by subclass).
|
||||
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||
- **ignore_for_config** (`list[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||
overridden by subclass).
|
||||
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
||||
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
|
||||
- **_deprecated_kwargs** (`list[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
|
||||
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
||||
subclass).
|
||||
"""
|
||||
@@ -143,7 +143,7 @@ class ConfigMixin:
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||
|
||||
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
def save_config(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
|
||||
[`~ConfigMixin.from_config`] class method.
|
||||
@@ -155,7 +155,7 @@ class ConfigMixin:
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
@@ -189,13 +189,13 @@ class ConfigMixin:
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
|
||||
) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
|
||||
cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs
|
||||
) -> Self | tuple[Self, dict[str, Any]]:
|
||||
r"""
|
||||
Instantiate a Python class from a config dictionary.
|
||||
|
||||
Parameters:
|
||||
config (`Dict[str, Any]`):
|
||||
config (`dict[str, Any]`):
|
||||
A config dictionary from which the Python class is instantiated. Make sure to only load configuration
|
||||
files of compatible classes.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
@@ -292,11 +292,11 @@ class ConfigMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_config(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
pretrained_model_name_or_path: str | os.PathLike,
|
||||
return_unused_kwargs=False,
|
||||
return_commit_hash=False,
|
||||
**kwargs,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
r"""
|
||||
Load a model or scheduler configuration.
|
||||
|
||||
@@ -315,7 +315,7 @@ class ConfigMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info (`bool`, *optional*, defaults to `False`):
|
||||
@@ -352,7 +352,7 @@ class ConfigMixin:
|
||||
_ = kwargs.pop("mirror", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
|
||||
user_agent = {**user_agent, "file_type": "config"}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
@@ -563,9 +563,7 @@ class ConfigMixin:
|
||||
return init_dict, unused_kwargs, hidden_config_dict
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(
|
||||
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
|
||||
):
|
||||
def _dict_from_json_file(cls, json_file: str | os.PathLike, dduf_entries: Optional[dict[str, DDUFEntry]] = None):
|
||||
if dduf_entries:
|
||||
text = dduf_entries[json_file].read_text()
|
||||
else:
|
||||
@@ -577,12 +575,12 @@ class ConfigMixin:
|
||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||
|
||||
@property
|
||||
def config(self) -> Dict[str, Any]:
|
||||
def config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the config of the class as a frozen dictionary
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Config of the class.
|
||||
`dict[str, Any]`: Config of the class.
|
||||
"""
|
||||
return self._internal_dict
|
||||
|
||||
@@ -625,7 +623,7 @@ class ConfigMixin:
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
def to_json_file(self, json_file_path: str | os.PathLike):
|
||||
"""
|
||||
Save the configuration instance's parameters to a JSON file.
|
||||
|
||||
@@ -637,7 +635,7 @@ class ConfigMixin:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
@classmethod
|
||||
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
|
||||
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: dict[str, DDUFEntry]):
|
||||
# paths inside a DDUF file must always be "/"
|
||||
config_file = (
|
||||
cls.config_name
|
||||
@@ -756,7 +754,7 @@ class LegacyConfigMixin(ConfigMixin):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
||||
def from_config(cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs):
|
||||
# To prevent dependency import problem.
|
||||
from .models.model_loading_utils import _fetch_remapped_cls_from_config
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ deps = {
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
"python": "python>=3.8.0",
|
||||
"python": "python>=3.9.0",
|
||||
"ruff": "ruff==0.9.10",
|
||||
"safetensors": "safetensors>=0.3.1",
|
||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -77,7 +79,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -36,10 +38,10 @@ class AutoGuidance(BaseGuidance):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
auto_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
auto_guidance_layers (`int` or `list[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `auto_guidance_config` must be provided.
|
||||
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
auto_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `auto_guidance_layers` must be provided.
|
||||
dropout (`float`, *optional*):
|
||||
@@ -65,8 +67,8 @@ class AutoGuidance(BaseGuidance):
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
||||
auto_guidance_layers: Optional[int | list[int]] = None,
|
||||
auto_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
|
||||
dropout: Optional[float] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
@@ -133,7 +135,7 @@ class AutoGuidance(BaseGuidance):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
registry.remove_hook(name, recurse=True)
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -91,7 +93,7 @@ class ClassifierFreeGuidance(BaseGuidance):
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -77,7 +79,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -37,7 +39,7 @@ else:
|
||||
build_laplacian_pyramid_func = None
|
||||
|
||||
|
||||
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
|
||||
(Algorithm 2).
|
||||
@@ -58,7 +60,7 @@ def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -
|
||||
return v0_parallel, v0_orthogonal
|
||||
|
||||
|
||||
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
|
||||
def build_image_from_pyramid(pyramid: list[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
|
||||
(Algorithm 2).
|
||||
@@ -99,19 +101,19 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
|
||||
guidance_scales (`list[float]`, defaults to `[10.0, 5.0]`):
|
||||
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
|
||||
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
|
||||
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
|
||||
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
|
||||
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
|
||||
descending order).
|
||||
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
|
||||
guidance_rescale (`float` or `list[float]`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
|
||||
`guidance_scales`.
|
||||
parallel_weights (`float` or `List[float]`, *optional*):
|
||||
parallel_weights (`float` or `list[float]`, *optional*):
|
||||
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
|
||||
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
|
||||
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
|
||||
@@ -120,10 +122,10 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float` or `List[float]`, defaults to `0.0`):
|
||||
start (`float` or `list[float]`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
stop (`float` or `List[float]`, defaults to `1.0`):
|
||||
stop (`float` or `list[float]`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
guidance_rescale_space (`str`, defaults to `"data"`):
|
||||
@@ -141,12 +143,12 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
|
||||
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
|
||||
guidance_scales: list[float] | tuple[float] = [10.0, 5.0],
|
||||
guidance_rescale: float | list[float] | tuple[float] = 0.0,
|
||||
parallel_weights: Optional[float | list[float] | tuple[float]] = None,
|
||||
use_original_formulation: bool = False,
|
||||
start: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
stop: Union[float, List[float], Tuple[float]] = 1.0,
|
||||
start: float | list[float] | tuple[float] = 0.0,
|
||||
stop: float | list[float] | tuple[float] = 1.0,
|
||||
guidance_rescale_space: str = "data",
|
||||
upcast_to_double: bool = True,
|
||||
enabled: bool = True,
|
||||
@@ -218,7 +220,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
@@ -51,8 +53,8 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._count_prepared = 0
|
||||
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
||||
self._enabled = enabled
|
||||
self._input_fields: dict[str, str | tuple[str, str]] = None
|
||||
self._enabled = True
|
||||
|
||||
if not (0.0 <= start < 1.0):
|
||||
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
||||
@@ -101,7 +103,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._timestep = timestep
|
||||
self._count_prepared = 0
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
def get_state(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
|
||||
the __repr__ method. Returns:
|
||||
@@ -163,10 +165,10 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: "BlockState") -> list["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
def __call__(self, data: List["BlockState"]) -> Any:
|
||||
def __call__(self, data: list["BlockState"]) -> Any:
|
||||
if not all(hasattr(d, "noise_pred") for d in data):
|
||||
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
||||
if len(data) != self.num_conditions:
|
||||
@@ -194,7 +196,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
@classmethod
|
||||
def _prepare_batch(
|
||||
cls,
|
||||
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
|
||||
data: dict[str, tuple[torch.Tensor, torch.Tensor]],
|
||||
tuple_index: int,
|
||||
identifier: str,
|
||||
) -> "BlockState":
|
||||
@@ -203,7 +205,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
||||
|
||||
Args:
|
||||
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
input_fields (`dict[str, str | tuple[str, str]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
||||
@@ -238,7 +240,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
pretrained_model_name_or_path: Optional[str | os.PathLike] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
return_unused_kwargs=False,
|
||||
**kwargs,
|
||||
@@ -265,7 +267,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info (`bool`, *optional*, defaults to `False`):
|
||||
@@ -295,7 +297,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a guider configuration object to a directory so that it can be reloaded using the
|
||||
[`~BaseGuidance.from_pretrained`] class method.
|
||||
@@ -307,7 +309,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -58,10 +60,10 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
|
||||
perturbed_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
|
||||
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
perturbed_guidance_layers (`int` or `list[int]`, *optional*):
|
||||
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
|
||||
If not provided, `perturbed_guidance_config` must be provided.
|
||||
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
perturbed_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
@@ -92,8 +94,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
perturbed_guidance_scale: float = 2.8,
|
||||
perturbed_guidance_start: float = 0.01,
|
||||
perturbed_guidance_stop: float = 0.2,
|
||||
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
||||
perturbed_guidance_layers: Optional[int | list[int]] = None,
|
||||
perturbed_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
@@ -169,7 +171,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -64,11 +66,11 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance starts.
|
||||
skip_layer_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance stops.
|
||||
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
skip_layer_guidance_layers (`int` or `list[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
||||
3.5 Medium.
|
||||
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
skip_layer_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
@@ -94,8 +96,8 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
skip_layer_guidance_scale: float = 2.8,
|
||||
skip_layer_guidance_start: float = 0.01,
|
||||
skip_layer_guidance_stop: float = 0.2,
|
||||
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
||||
skip_layer_guidance_layers: Optional[int | list[int]] = None,
|
||||
skip_layer_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
@@ -165,7 +167,7 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -54,11 +56,11 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
|
||||
seg_guidance_stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
|
||||
seg_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
seg_guidance_layers (`int` or `list[int]`, *optional*):
|
||||
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
|
||||
not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
|
||||
Diffusion 3.5 Medium.
|
||||
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
|
||||
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `list[SmoothedEnergyGuidanceConfig]`, *optional*):
|
||||
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
|
||||
a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
@@ -86,8 +88,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
seg_blur_threshold_inf: float = 9999.0,
|
||||
seg_guidance_start: float = 0.0,
|
||||
seg_guidance_stop: float = 1.0,
|
||||
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
|
||||
seg_guidance_layers: Optional[int | list[int]] = None,
|
||||
seg_guidance_config: SmoothedEnergyGuidanceConfig | list[SmoothedEnergyGuidanceConfig] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
@@ -154,7 +156,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
for hook_name in self._seg_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -66,7 +68,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Type
|
||||
from typing import Any, Callable, Type
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -28,7 +28,7 @@ class TransformerBlockMetadata:
|
||||
return_encoder_hidden_states_index: int = None
|
||||
|
||||
_cls: Type = None
|
||||
_cached_parameter_indices: Dict[str, int] = None
|
||||
_cached_parameter_indices: dict[str, int] = None
|
||||
|
||||
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Type, Union
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -42,7 +42,7 @@ _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
|
||||
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
|
||||
@dataclass
|
||||
class ModuleForwardMetadata:
|
||||
cached_parameter_indices: Dict[str, int] = None
|
||||
cached_parameter_indices: dict[str, int] = None
|
||||
_cls: Type = None
|
||||
|
||||
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
||||
@@ -78,7 +78,7 @@ class ModuleForwardMetadata:
|
||||
def apply_context_parallel(
|
||||
module: torch.nn.Module,
|
||||
parallel_config: ContextParallelConfig,
|
||||
plan: Dict[str, ContextParallelModelPlan],
|
||||
plan: dict[str, ContextParallelModelPlan],
|
||||
) -> None:
|
||||
"""Apply context parallel on a model."""
|
||||
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
|
||||
@@ -107,7 +107,7 @@ def apply_context_parallel(
|
||||
registry.register_hook(hook, hook_name)
|
||||
|
||||
|
||||
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
|
||||
def remove_context_parallel(module: torch.nn.Module, plan: dict[str, ContextParallelModelPlan]) -> None:
|
||||
for module_id, cp_model_plan in plan.items():
|
||||
submodule = _get_submodule_by_name(module, module_id)
|
||||
if not isinstance(submodule, list):
|
||||
@@ -272,13 +272,13 @@ class EquipartitionSharder:
|
||||
return tensor
|
||||
|
||||
|
||||
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
|
||||
if name.count("*") > 1:
|
||||
raise ValueError("Wildcard '*' can only be used once in the name")
|
||||
return _find_submodule_by_name(model, name)
|
||||
|
||||
|
||||
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
|
||||
if name == "":
|
||||
return model
|
||||
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -60,7 +60,7 @@ class FasterCacheConfig:
|
||||
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
||||
be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
|
||||
states again.
|
||||
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
|
||||
spatial_attention_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 681)`):
|
||||
The timestep range within which the spatial attention computation can be skipped without a significant loss
|
||||
in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
@@ -68,17 +68,17 @@ class FasterCacheConfig:
|
||||
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
|
||||
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
|
||||
process.
|
||||
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
|
||||
temporal_attention_timestep_skip_range (`tuple[float, float]`, *optional*, defaults to `(-1, 681)`):
|
||||
The timestep range within which the temporal attention computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
||||
timestep 0).
|
||||
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
|
||||
low_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(99, 901)`):
|
||||
The timestep range within which the low frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
|
||||
high_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(-1, 301)`):
|
||||
The timestep range within which the high frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
@@ -92,15 +92,15 @@ class FasterCacheConfig:
|
||||
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
|
||||
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be reused) before
|
||||
computing the new unconditional branch states again.
|
||||
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
|
||||
unconditional_batch_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 641)`):
|
||||
The timestep range within which the unconditional branch computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound.
|
||||
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
|
||||
spatial_attention_block_identifiers (`tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
|
||||
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
|
||||
temporal_attention_block_identifiers (`tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
|
||||
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
@@ -123,7 +123,7 @@ class FasterCacheConfig:
|
||||
is_guidance_distilled (`bool`, defaults to `False`):
|
||||
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
|
||||
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
|
||||
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
|
||||
_unconditional_conditional_input_kwargs_identifiers (`list[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
|
||||
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
|
||||
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
|
||||
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
|
||||
@@ -135,12 +135,12 @@ class FasterCacheConfig:
|
||||
spatial_attention_block_skip_range: int = 2
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
spatial_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
|
||||
temporal_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
|
||||
|
||||
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
|
||||
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
|
||||
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
|
||||
low_frequency_weight_update_timestep_range: tuple[int, int] = (99, 901)
|
||||
high_frequency_weight_update_timestep_range: tuple[int, int] = (-1, 301)
|
||||
|
||||
# ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
|
||||
alpha_low_frequency: float = 1.1
|
||||
@@ -148,10 +148,10 @@ class FasterCacheConfig:
|
||||
|
||||
# n as described in CFG-Cache explanation in the paper - dependent on the model
|
||||
unconditional_batch_skip_range: int = 5
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
|
||||
unconditional_batch_timestep_skip_range: tuple[int, int] = (-1, 641)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
|
||||
attention_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
@@ -162,7 +162,7 @@ class FasterCacheConfig:
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
|
||||
_unconditional_conditional_input_kwargs_identifiers: list[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
@@ -209,7 +209,7 @@ class FasterCacheBlockState:
|
||||
def __init__(self) -> None:
|
||||
self.iteration: int = 0
|
||||
self.batch_size: int = None
|
||||
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
|
||||
self.cache: tuple[torch.Tensor, torch.Tensor] = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
@@ -223,10 +223,10 @@ class FasterCacheDenoiserHook(ModelHook):
|
||||
def __init__(
|
||||
self,
|
||||
unconditional_batch_skip_range: int,
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int],
|
||||
unconditional_batch_timestep_skip_range: tuple[int, int],
|
||||
tensor_format: str,
|
||||
is_guidance_distilled: bool,
|
||||
uncond_cond_input_kwargs_identifiers: List[str],
|
||||
uncond_cond_input_kwargs_identifiers: list[str],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
@@ -252,7 +252,7 @@ class FasterCacheDenoiserHook(ModelHook):
|
||||
return module
|
||||
|
||||
@staticmethod
|
||||
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_cond_input(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
|
||||
# followed by conditional inputs.
|
||||
_, cond = input.chunk(2, dim=0)
|
||||
@@ -371,7 +371,7 @@ class FasterCacheBlockHook(ModelHook):
|
||||
def __init__(
|
||||
self,
|
||||
block_skip_range: int,
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
timestep_skip_range: tuple[int, int],
|
||||
is_guidance_distilled: bool,
|
||||
weight_callback: Callable[[torch.nn.Module], float],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -53,9 +52,9 @@ class FBCSharedBlockState(BaseState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.head_block_output: torch.Tensor | tuple[torch.Tensor, ...] = None
|
||||
self.head_block_residual: torch.Tensor = None
|
||||
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.tail_block_residuals: torch.Tensor | tuple[torch.Tensor, ...] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -17,7 +17,7 @@ import os
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Optional, Set
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
@@ -58,21 +58,21 @@ class GroupOffloadingConfig:
|
||||
low_cpu_mem_usage: bool
|
||||
num_blocks_per_group: Optional[int] = None
|
||||
offload_to_disk_path: Optional[str] = None
|
||||
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
|
||||
stream: Optional[torch.cuda.Stream | torch.Stream] = None
|
||||
|
||||
|
||||
class ModuleGroup:
|
||||
def __init__(
|
||||
self,
|
||||
modules: List[torch.nn.Module],
|
||||
modules: list[torch.nn.Module],
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
offload_leader: torch.nn.Module,
|
||||
onload_leader: Optional[torch.nn.Module] = None,
|
||||
parameters: Optional[List[torch.nn.Parameter]] = None,
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
parameters: Optional[list[torch.nn.Parameter]] = None,
|
||||
buffers: Optional[list[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
stream: torch.cuda.Stream | torch.Stream | None = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
onload_self: bool = True,
|
||||
@@ -340,7 +340,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self):
|
||||
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
|
||||
self.execution_order: list[tuple[str, torch.nn.Module]] = []
|
||||
self._layer_execution_tracker_module_names = set()
|
||||
|
||||
def initialize_hook(self, module):
|
||||
@@ -444,9 +444,9 @@ class LayerExecutionTrackerHook(ModelHook):
|
||||
|
||||
def apply_group_offloading(
|
||||
module: torch.nn.Module,
|
||||
onload_device: Union[str, torch.device],
|
||||
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
||||
offload_type: Union[str, GroupOffloadingType] = "block_level",
|
||||
onload_device: str | torch.device,
|
||||
offload_device: str | torch.device = torch.device("cpu"),
|
||||
offload_type: str | GroupOffloadingType = "block_level",
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
@@ -787,7 +787,7 @@ def _apply_lazy_group_offloading_hook(
|
||||
|
||||
def _gather_parameters_with_no_group_offloading_parent(
|
||||
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
||||
) -> List[torch.nn.Parameter]:
|
||||
) -> list[torch.nn.Parameter]:
|
||||
parameters = []
|
||||
for name, parameter in module.named_parameters():
|
||||
has_parent_with_group_offloading = False
|
||||
@@ -805,7 +805,7 @@ def _gather_parameters_with_no_group_offloading_parent(
|
||||
|
||||
def _gather_buffers_with_no_group_offloading_parent(
|
||||
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
||||
) -> List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
buffers = []
|
||||
for name, buffer in module.named_buffers():
|
||||
has_parent_with_group_offloading = False
|
||||
@@ -821,7 +821,7 @@ def _gather_buffers_with_no_group_offloading_parent(
|
||||
return buffers
|
||||
|
||||
|
||||
def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
|
||||
def _find_parent_module_in_module_dict(name: str, module_dict: dict[str, torch.nn.Module]) -> str:
|
||||
atoms = name.split(".")
|
||||
while len(atoms) > 0:
|
||||
parent_name = ".".join(atoms)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -86,19 +86,19 @@ class ModelHook:
|
||||
"""
|
||||
return module
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> tuple[tuple[Any], dict[str, Any]]:
|
||||
r"""
|
||||
Hook that is executed just before the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass will be executed just after this event.
|
||||
args (`Tuple[Any]`):
|
||||
args (`tuple[Any]`):
|
||||
The positional arguments passed to the module.
|
||||
kwargs (`Dict[Str, Any]`):
|
||||
kwargs (`dict[str, Any]`):
|
||||
The keyword arguments passed to the module.
|
||||
Returns:
|
||||
`Tuple[Tuple[Any], Dict[Str, Any]]`:
|
||||
`tuple[tuple[Any], dict[str, Any]]`:
|
||||
A tuple with the treated `args` and `kwargs`.
|
||||
"""
|
||||
return args, kwargs
|
||||
@@ -168,7 +168,7 @@ class HookRegistry:
|
||||
def __init__(self, module_ref: torch.nn.Module) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hooks: Dict[str, ModelHook] = {}
|
||||
self.hooks: dict[str, ModelHook] = {}
|
||||
|
||||
self._module_ref = module_ref
|
||||
self._hook_order = []
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -43,7 +43,7 @@ class LayerSkipConfig:
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
indices (`list[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
@@ -63,7 +63,7 @@ class LayerSkipConfig:
|
||||
skipped layers are fully retained, which is equivalent to not skipping any layers.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
indices: list[int]
|
||||
fqn: str = "auto"
|
||||
skip_attention: bool = True
|
||||
skip_attention_scores: bool = False
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
from typing import Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -102,8 +102,8 @@ def apply_layerwise_casting(
|
||||
module: torch.nn.Module,
|
||||
storage_dtype: torch.dtype,
|
||||
compute_dtype: torch.dtype,
|
||||
skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
|
||||
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
|
||||
skip_modules_pattern: str | tuple[str, ...] = "auto",
|
||||
skip_modules_classes: Optional[tuple[Type[torch.nn.Module], ...]] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
@@ -137,12 +137,12 @@ def apply_layerwise_casting(
|
||||
The dtype to cast the module to before/after the forward pass for storage.
|
||||
compute_dtype (`torch.dtype`):
|
||||
The dtype to cast the module to during the forward pass for computation.
|
||||
skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
|
||||
skip_modules_pattern (`str` or `tuple[str, ...]`, defaults to `"auto"`):
|
||||
A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
|
||||
to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
|
||||
alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
|
||||
instead of its internal submodules.
|
||||
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
|
||||
skip_modules_classes (`tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
|
||||
A list of module classes to skip during the layerwise casting process.
|
||||
non_blocking (`bool`, defaults to `False`):
|
||||
If `True`, the weight casting operations are non-blocking.
|
||||
@@ -169,8 +169,8 @@ def _apply_layerwise_casting(
|
||||
module: torch.nn.Module,
|
||||
storage_dtype: torch.dtype,
|
||||
compute_dtype: torch.dtype,
|
||||
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
|
||||
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
|
||||
skip_modules_pattern: Optional[tuple[str, ...]] = None,
|
||||
skip_modules_classes: Optional[tuple[Type[torch.nn.Module], ...]] = None,
|
||||
non_blocking: bool = False,
|
||||
_prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -54,20 +54,20 @@ class PyramidAttentionBroadcastConfig:
|
||||
The number of times a specific cross-attention broadcast is skipped before computing the attention states
|
||||
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
|
||||
old attention states will be reused) before computing the new attention states again.
|
||||
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
spatial_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the spatial attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
temporal_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the temporal attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
cross_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the cross-attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
spatial_attention_block_identifiers (`Tuple[str, ...]`):
|
||||
spatial_attention_block_identifiers (`tuple[str, ...]`):
|
||||
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
|
||||
temporal_attention_block_identifiers (`Tuple[str, ...]`):
|
||||
temporal_attention_block_identifiers (`tuple[str, ...]`):
|
||||
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
|
||||
cross_attention_block_identifiers (`Tuple[str, ...]`):
|
||||
cross_attention_block_identifiers (`tuple[str, ...]`):
|
||||
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
|
||||
"""
|
||||
|
||||
@@ -75,13 +75,13 @@ class PyramidAttentionBroadcastConfig:
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
cross_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
spatial_attention_timestep_skip_range: tuple[int, int] = (100, 800)
|
||||
temporal_attention_timestep_skip_range: tuple[int, int] = (100, 800)
|
||||
cross_attention_timestep_skip_range: tuple[int, int] = (100, 800)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
@@ -141,7 +141,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
|
||||
self, timestep_skip_range: tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -288,8 +288,8 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
|
||||
|
||||
|
||||
def _apply_pyramid_attention_broadcast_hook(
|
||||
module: Union[Attention, MochiAttention],
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
module: Attention | MochiAttention,
|
||||
timestep_skip_range: tuple[int, int],
|
||||
block_skip_range: int,
|
||||
current_timestep_callback: Callable[[], int],
|
||||
):
|
||||
@@ -299,7 +299,7 @@ def _apply_pyramid_attention_broadcast_hook(
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply Pyramid Attention Broadcast to.
|
||||
timestep_skip_range (`Tuple[int, int]`):
|
||||
timestep_skip_range (`tuple[int, int]`):
|
||||
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
|
||||
skipped if the current timestep is within the specified range.
|
||||
block_skip_range (`int`):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -35,21 +35,21 @@ class SmoothedEnergyGuidanceConfig:
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
indices (`list[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
|
||||
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
|
||||
provide the correct fqn.
|
||||
_query_proj_identifiers (`List[str]`, defaults to `None`):
|
||||
_query_proj_identifiers (`list[str]`, defaults to `None`):
|
||||
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
|
||||
`None`, `to_q` is used by default.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
indices: list[int]
|
||||
fqn: str = "auto"
|
||||
_query_proj_identifiers: List[str] = None
|
||||
_query_proj_identifiers: list[str] = None
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
@@ -21,8 +21,8 @@ def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
|
||||
module_list_with_transformer_blocks = []
|
||||
for name, submodule in module.named_modules():
|
||||
name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
|
||||
is_modulelist = isinstance(submodule, torch.nn.ModuleList)
|
||||
if name_endswith_identifier and is_modulelist:
|
||||
is_ModuleList = isinstance(submodule, torch.nn.ModuleList)
|
||||
if name_endswith_identifier and is_ModuleList:
|
||||
module_list_with_transformer_blocks.append((name, submodule))
|
||||
return module_list_with_transformer_blocks
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -26,14 +26,9 @@ from .configuration_utils import ConfigMixin, register_to_config
|
||||
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
||||
|
||||
|
||||
PipelineImageInput = Union[
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
torch.Tensor,
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
List[torch.Tensor],
|
||||
]
|
||||
PipelineImageInput = (
|
||||
PIL.Image.Image | np.ndarray | torch.Tensor | list[PIL.Image.Image] | list[np.ndarray] | list[torch.Tensor]
|
||||
)
|
||||
|
||||
PipelineDepthInput = PipelineImageInput
|
||||
|
||||
@@ -68,7 +63,7 @@ def is_valid_image_imagelist(images):
|
||||
- A list of valid images.
|
||||
|
||||
Args:
|
||||
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
|
||||
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, list]`):
|
||||
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
|
||||
images.
|
||||
|
||||
@@ -131,7 +126,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]:
|
||||
r"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
|
||||
@@ -140,7 +135,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The image array to convert to PIL format.
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`:
|
||||
`list[PIL.Image.Image]`:
|
||||
A list of PIL images.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
@@ -155,12 +150,12 @@ class VaeImageProcessor(ConfigMixin):
|
||||
return pil_images
|
||||
|
||||
@staticmethod
|
||||
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
||||
def pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
|
||||
r"""
|
||||
Convert a PIL image or a list of PIL images to NumPy arrays.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
|
||||
images (`PIL.Image.Image` or `list[PIL.Image.Image]`):
|
||||
The PIL image or list of images to convert to NumPy format.
|
||||
|
||||
Returns:
|
||||
@@ -210,7 +205,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
return images
|
||||
|
||||
@staticmethod
|
||||
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
def normalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
|
||||
r"""
|
||||
Normalize an image array to [-1,1].
|
||||
|
||||
@@ -225,7 +220,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
return 2.0 * images - 1.0
|
||||
|
||||
@staticmethod
|
||||
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
def denormalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
|
||||
r"""
|
||||
Denormalize an image array to [0,1].
|
||||
|
||||
@@ -467,11 +462,11 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
image: PIL.Image.Image | np.ndarray | torch.Tensor,
|
||||
height: int,
|
||||
width: int,
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
|
||||
"""
|
||||
Resize image.
|
||||
|
||||
@@ -544,7 +539,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
return image
|
||||
|
||||
def _denormalize_conditionally(
|
||||
self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
|
||||
self, images: torch.Tensor, do_denormalize: Optional[list[bool]] = None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Denormalize a batch of images based on a condition list.
|
||||
@@ -552,7 +547,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The input image tensor.
|
||||
do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
|
||||
do_denormalize (`Optional[list[bool]]`, *optional*, defaults to `None`):
|
||||
A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
|
||||
value of `do_normalize` in the `VaeImageProcessor` config.
|
||||
"""
|
||||
@@ -565,10 +560,10 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
def get_default_height_width(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
image: PIL.Image.Image | np.ndarray | torch.Tensor,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
) -> Tuple[int, int]:
|
||||
) -> tuple[int, int]:
|
||||
r"""
|
||||
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
|
||||
|
||||
@@ -583,7 +578,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`:
|
||||
`tuple[int, int]`:
|
||||
A tuple containing the height and width, both resized to the nearest integer multiple of
|
||||
`vae_scale_factor`.
|
||||
"""
|
||||
@@ -616,7 +611,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
crops_coords: Optional[tuple[int, int, int, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess the image input.
|
||||
@@ -638,7 +633,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
||||
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
||||
supported for PIL image input.
|
||||
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
crops_coords (`list[tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
||||
|
||||
Returns:
|
||||
@@ -745,8 +740,8 @@ class VaeImageProcessor(ConfigMixin):
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
output_type: str = "pil",
|
||||
do_denormalize: Optional[List[bool]] = None,
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
do_denormalize: Optional[list[bool]] = None,
|
||||
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
|
||||
"""
|
||||
Postprocess the image output from tensor to `output_type`.
|
||||
|
||||
@@ -755,7 +750,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The image input, should be a pytorch tensor with shape `B x C x H x W`.
|
||||
output_type (`str`, *optional*, defaults to `pil`):
|
||||
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
||||
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
|
||||
do_denormalize (`list[bool]`, *optional*, defaults to `None`):
|
||||
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
|
||||
`VaeImageProcessor` config.
|
||||
|
||||
@@ -796,7 +791,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
mask: PIL.Image.Image,
|
||||
init_image: PIL.Image.Image,
|
||||
image: PIL.Image.Image,
|
||||
crop_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
crop_coords: Optional[tuple[int, int, int, int]] = None,
|
||||
) -> PIL.Image.Image:
|
||||
r"""
|
||||
Applies an overlay of the mask and the inpainted image on the original image.
|
||||
@@ -808,7 +803,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The original image to which the overlay is applied.
|
||||
image (`PIL.Image.Image`):
|
||||
The image to overlay onto the original.
|
||||
crop_coords (`Tuple[int, int, int, int]`, *optional*):
|
||||
crop_coords (`tuple[int, int, int, int]`, *optional*):
|
||||
Coordinates to crop the image. If provided, the image will be cropped accordingly.
|
||||
|
||||
Returns:
|
||||
@@ -891,7 +886,7 @@ class InpaintProcessor(ConfigMixin):
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preprocess the image and mask.
|
||||
"""
|
||||
@@ -946,8 +941,8 @@ class InpaintProcessor(ConfigMixin):
|
||||
output_type: str = "pil",
|
||||
original_image: Optional[PIL.Image.Image] = None,
|
||||
original_mask: Optional[PIL.Image.Image] = None,
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
|
||||
crops_coords: Optional[tuple[int, int, int, int]] = None,
|
||||
) -> tuple[PIL.Image.Image, PIL.Image.Image]:
|
||||
"""
|
||||
Postprocess the image, optionally apply mask overlay
|
||||
"""
|
||||
@@ -998,7 +993,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]:
|
||||
r"""
|
||||
Convert a NumPy image or a batch of images to a list of PIL images.
|
||||
|
||||
@@ -1007,7 +1002,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
The input NumPy array of images, which can be a single image or a batch.
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`:
|
||||
`list[PIL.Image.Image]`:
|
||||
A list of PIL images converted from the input NumPy array.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
@@ -1022,12 +1017,12 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
return pil_images
|
||||
|
||||
@staticmethod
|
||||
def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
||||
def depth_pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
|
||||
r"""
|
||||
Convert a PIL image or a list of PIL images to NumPy arrays.
|
||||
|
||||
Args:
|
||||
images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
|
||||
images (`Union[list[PIL.Image.Image], PIL.Image.Image]`):
|
||||
The input image or list of images to be converted.
|
||||
|
||||
Returns:
|
||||
@@ -1042,7 +1037,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
return images
|
||||
|
||||
@staticmethod
|
||||
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
def rgblike_to_depthmap(image: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
|
||||
r"""
|
||||
Convert an RGB-like depth image to a depth map.
|
||||
|
||||
@@ -1056,7 +1051,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
"""
|
||||
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
||||
|
||||
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
def numpy_to_depth(self, images: np.ndarray) -> list[PIL.Image.Image]:
|
||||
r"""
|
||||
Convert a NumPy depth image or a batch of images to a list of PIL images.
|
||||
|
||||
@@ -1065,7 +1060,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
The input NumPy array of depth images, which can be a single image or a batch.
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`:
|
||||
`list[PIL.Image.Image]`:
|
||||
A list of PIL images converted from the input NumPy depth images.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
@@ -1088,8 +1083,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
output_type: str = "pil",
|
||||
do_denormalize: Optional[List[bool]] = None,
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
do_denormalize: Optional[list[bool]] = None,
|
||||
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
|
||||
"""
|
||||
Postprocess the image output from tensor to `output_type`.
|
||||
|
||||
@@ -1098,7 +1093,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
The image input, should be a pytorch tensor with shape `B x C x H x W`.
|
||||
output_type (`str`, *optional*, defaults to `pil`):
|
||||
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
||||
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
|
||||
do_denormalize (`list[bool]`, *optional*, defaults to `None`):
|
||||
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
|
||||
`VaeImageProcessor` config.
|
||||
|
||||
@@ -1136,8 +1131,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
|
||||
depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
|
||||
rgb: torch.Tensor | PIL.Image.Image | np.ndarray,
|
||||
depth: torch.Tensor | PIL.Image.Image | np.ndarray,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
target_res: Optional[int] = None,
|
||||
@@ -1158,7 +1153,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
Target resolution for resizing the images. If specified, overrides height and width.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
`tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing the processed RGB and depth images as PyTorch tensors.
|
||||
"""
|
||||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
||||
@@ -1396,7 +1391,7 @@ class PixArtImageProcessor(VaeImageProcessor):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
||||
def classify_height_width_bin(height: int, width: int, ratios: dict) -> tuple[int, int]:
|
||||
r"""
|
||||
Returns the binned height and width based on the aspect ratio.
|
||||
|
||||
@@ -1406,7 +1401,7 @@ class PixArtImageProcessor(VaeImageProcessor):
|
||||
ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`: The closest binned height and width.
|
||||
`tuple[int, int]`: The closest binned height and width.
|
||||
"""
|
||||
ar = float(height / width)
|
||||
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -57,15 +57,15 @@ class IPAdapterMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
|
||||
subfolder: str | list[str],
|
||||
weight_name: str | list[str],
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
@@ -74,10 +74,10 @@ class IPAdapterMixin:
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
subfolder (`str` or `list[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
weight_name (`str` or `list[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
@@ -94,7 +94,7 @@ class IPAdapterMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -358,14 +358,14 @@ class ModularIPAdapterMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
|
||||
subfolder: str | list[str],
|
||||
weight_name: str | list[str],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
@@ -374,10 +374,10 @@ class ModularIPAdapterMixin:
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
subfolder (`str` or `list[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
weight_name (`str` or `list[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
@@ -387,7 +387,7 @@ class ModularIPAdapterMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -608,9 +608,9 @@ class FluxIPAdapterMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
weight_name: Union[str, List[str]],
|
||||
subfolder: Optional[Union[str, List[str]]] = "",
|
||||
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
|
||||
weight_name: str | list[str],
|
||||
subfolder: Optional[str | list[str]] = "",
|
||||
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
|
||||
image_encoder_subfolder: Optional[str] = "",
|
||||
image_encoder_dtype: torch.dtype = torch.float16,
|
||||
@@ -618,7 +618,7 @@ class FluxIPAdapterMixin:
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
@@ -627,10 +627,10 @@ class FluxIPAdapterMixin:
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
subfolder (`str` or `list[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
weight_name (`str` or `list[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `image_encoder`):
|
||||
@@ -647,7 +647,7 @@ class FluxIPAdapterMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -797,13 +797,13 @@ class FluxIPAdapterMixin:
|
||||
# load ip-adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
|
||||
def set_ip_adapter_scale(self, scale: float | list[float] | list[list[float]]):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a list.
|
||||
|
||||
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
|
||||
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
|
||||
`float` is converted to a list and repeated for the number of blocks and the number of IP adapters. A `list[float]`
|
||||
whose length matches the number of blocks is repeated for each IP adapter. A `list[list[float]]` must match the
|
||||
number of IP adapters and each must match the number of blocks.
|
||||
|
||||
Example:
|
||||
@@ -823,18 +823,18 @@ class FluxIPAdapterMixin:
|
||||
```
|
||||
"""
|
||||
|
||||
scale_type = Union[int, float]
|
||||
scale_type = int | float
|
||||
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
num_layers = self.transformer.config.num_layers
|
||||
|
||||
# Single value for all layers of all IP-Adapters
|
||||
if isinstance(scale, scale_type):
|
||||
scale = [scale for _ in range(num_ip_adapters)]
|
||||
# List of per-layer scales for a single IP-Adapter
|
||||
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
|
||||
# List of per-layer scales for a single IP-Adapter
|
||||
elif _is_valid_type(scale, list[scale_type]) and num_ip_adapters == 1:
|
||||
scale = [scale]
|
||||
# Invalid scale type
|
||||
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
|
||||
elif not _is_valid_type(scale, list[scale_type | list[scale_type]]):
|
||||
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
|
||||
|
||||
if len(scale) != num_ip_adapters:
|
||||
@@ -918,7 +918,7 @@ class SD3IPAdapterMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
weight_name: str = "ip-adapter.safetensors",
|
||||
subfolder: Optional[str] = None,
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
@@ -953,7 +953,7 @@ class SD3IPAdapterMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
|
||||
@@ -17,7 +17,7 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -77,7 +77,7 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adap
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]` or `str`):
|
||||
adapter_names (`list[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
"""
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
@@ -116,20 +116,20 @@ def unfuse_text_encoder_lora(text_encoder):
|
||||
|
||||
|
||||
def set_adapters_for_text_encoder(
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_names: list[str] | str,
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
||||
text_encoder_weights: Optional[float | list[float] | list[None]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the text encoder.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
adapter_names (`list[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
text_encoder_weights (`List[float]`, *optional*):
|
||||
text_encoder_weights (`list[float]`, *optional*):
|
||||
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
@@ -535,10 +535,10 @@ class LoraBaseMixin:
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = [],
|
||||
components: list[str] = [],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -547,12 +547,12 @@ class LoraBaseMixin:
|
||||
> [!WARNING] > This is an experimental API.
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
components: (`list[str]`): list of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
@@ -619,7 +619,7 @@ class LoraBaseMixin:
|
||||
|
||||
self._merged_adapters = self._merged_adapters | merged_adapter_names
|
||||
|
||||
def unfuse_lora(self, components: List[str] = [], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = [], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
@@ -627,7 +627,7 @@ class LoraBaseMixin:
|
||||
> [!WARNING] > This is an experimental API.
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
components (`list[str]`): list of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
@@ -674,16 +674,16 @@ class LoraBaseMixin:
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||
adapter_names: list[str] | str,
|
||||
adapter_weights: Optional[float | Dict | list[float] | list[Dict]] = None,
|
||||
):
|
||||
"""
|
||||
Set the currently active adapters for use in the pipeline.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
adapter_names (`list[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
adapter_weights (`Union[List[float], float]`, *optional*):
|
||||
adapter_weights (`Union[list[float], float]`, *optional*):
|
||||
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
|
||||
adapters.
|
||||
|
||||
@@ -835,12 +835,12 @@ class LoraBaseMixin:
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
enable_lora_for_text_encoder(model)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
def delete_adapters(self, adapter_names: list[str] | str):
|
||||
"""
|
||||
Delete an adapter's LoRA layers from the pipeline.
|
||||
|
||||
Args:
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
adapter_names (`Union[list[str], str]`):
|
||||
The names of the adapters to delete.
|
||||
|
||||
Example:
|
||||
@@ -873,7 +873,7 @@ class LoraBaseMixin:
|
||||
for adapter_name in adapter_names:
|
||||
delete_adapter_layers(model, adapter_name)
|
||||
|
||||
def get_active_adapters(self) -> List[str]:
|
||||
def get_active_adapters(self) -> list[str]:
|
||||
"""
|
||||
Gets the list of the current active adapters.
|
||||
|
||||
@@ -906,7 +906,7 @@ class LoraBaseMixin:
|
||||
|
||||
return active_adapters
|
||||
|
||||
def get_list_adapters(self) -> Dict[str, List[str]]:
|
||||
def get_list_adapters(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Gets the current list of all available adapters in the pipeline.
|
||||
"""
|
||||
@@ -928,7 +928,7 @@ class LoraBaseMixin:
|
||||
|
||||
return set_adapters
|
||||
|
||||
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||
def set_lora_device(self, adapter_names: list[str], device: torch.device | str | int) -> None:
|
||||
"""
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
you want to load multiple adapters and free some GPU memory.
|
||||
@@ -955,8 +955,8 @@ class LoraBaseMixin:
|
||||
```
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]`):
|
||||
List of adapters to send device to.
|
||||
adapter_names (`list[str]`):
|
||||
list of adapters to send device to.
|
||||
device (`Union[torch.device, str, int]`):
|
||||
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
||||
"""
|
||||
@@ -1007,7 +1007,7 @@ class LoraBaseMixin:
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
is_main_process: bool,
|
||||
weight_name: str,
|
||||
@@ -1059,9 +1059,9 @@ class LoraBaseMixin:
|
||||
@classmethod
|
||||
def _save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
|
||||
lora_metadata: Dict[str, Optional[dict]],
|
||||
save_directory: str | os.PathLike,
|
||||
lora_layers: dict[str, dict[str, torch.nn.Module | torch.Tensor]],
|
||||
lora_metadata: dict[str, Optional[dict]],
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1021,7 +1020,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def _custom_replace(key: str, substrings: List[str]) -> str:
|
||||
def _custom_replace(key: str, substrings: list[str]) -> str:
|
||||
# Replaces the "."s with "_"s upto the `substrings`.
|
||||
# Example:
|
||||
# lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
@@ -137,7 +137,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -240,7 +240,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -267,7 +267,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -367,7 +367,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -429,7 +429,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -469,9 +469,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
unet_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
text_encoder_lora_layers: dict[str, torch.nn.Module] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -485,9 +485,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
unet_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `unet`.
|
||||
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
text_encoder_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
|
||||
encoder LoRA state dict because it comes from 🤗 Transformers.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
@@ -531,10 +531,10 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["unet", "text_encoder"],
|
||||
components: list[str] = ["unet", "text_encoder"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -543,12 +543,12 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
> [!WARNING] > This is an experimental API.
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
components: (`list[str]`): list of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
@@ -572,7 +572,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["unet", "text_encoder"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
@@ -580,7 +580,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
> [!WARNING] > This is an experimental API.
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
components (`list[str]`): list of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
@@ -602,7 +602,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -679,7 +679,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -706,7 +706,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -807,7 +807,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -870,7 +870,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -910,10 +910,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
unet_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
text_encoder_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
text_encoder_2_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -957,10 +957,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
|
||||
components: list[str] = ["unet", "text_encoder", "text_encoder_2"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -974,7 +974,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -998,7 +998,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -1050,7 +1050,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name=None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -1166,7 +1166,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -1207,10 +1207,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
text_encoder_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
text_encoder_2_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -1255,10 +1255,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
|
||||
components: list[str] = ["transformer", "text_encoder", "text_encoder_2"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -1273,7 +1273,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -1293,7 +1293,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -1346,7 +1346,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -1421,8 +1421,8 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -1455,10 +1455,10 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -1473,7 +1473,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -1497,7 +1497,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
return_alphas: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -1620,7 +1620,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -1782,7 +1782,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer,
|
||||
prefix=None,
|
||||
discard_original_layers=False,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
) -> dict[str, torch.Tensor]:
|
||||
# Remove prefix if present
|
||||
prefix = prefix or cls.transformer_name
|
||||
for key in list(state_dict.keys()):
|
||||
@@ -1851,7 +1851,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -1892,9 +1892,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
text_encoder_lora_layers: dict[str, torch.nn.Module] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -1908,9 +1908,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
transformer_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
text_encoder_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
|
||||
encoder LoRA state dict because it comes from 🤗 Transformers.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
@@ -1954,10 +1954,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -1984,7 +1984,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
@@ -1992,7 +1992,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
> [!WARNING] > This is an experimental API.
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
components (`list[str]`): list of LoRA-injectable components to unfuse LoRA from.
|
||||
"""
|
||||
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
|
||||
@@ -2341,7 +2341,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -2381,9 +2381,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
text_encoder_lora_layers: dict[str, torch.nn.Module] = None,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -2395,9 +2395,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
unet_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `unet`.
|
||||
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
text_encoder_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
|
||||
encoder LoRA state dict because it comes from 🤗 Transformers.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
@@ -2446,7 +2446,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -2498,7 +2498,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -2572,8 +2572,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -2605,10 +2605,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -2622,7 +2622,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -2642,7 +2642,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -2695,7 +2695,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -2770,8 +2770,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -2804,10 +2804,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -2822,7 +2822,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -2841,7 +2841,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -2898,7 +2898,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -2973,8 +2973,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -3007,10 +3007,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -3025,7 +3025,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -3045,7 +3045,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -3098,7 +3098,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -3173,8 +3173,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -3207,10 +3207,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -3225,7 +3225,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -3244,7 +3244,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -3301,7 +3301,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -3376,8 +3376,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -3410,10 +3410,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -3428,7 +3428,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -3447,7 +3447,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -3505,7 +3505,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -3580,8 +3580,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -3614,10 +3614,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -3632,7 +3632,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -3651,7 +3651,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -3669,7 +3669,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files.
|
||||
@@ -3731,7 +3731,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -3832,8 +3832,8 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -3846,7 +3846,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
transformer_lora_layers (`dict[str, torch.nn.Module]` or `dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process.
|
||||
@@ -3879,22 +3879,22 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
components (`list[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
Adapter names to be used for fusing.
|
||||
|
||||
Example:
|
||||
@@ -3914,12 +3914,12 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of [`pipe.fuse_lora()`].
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
components (`list[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
@@ -3936,7 +3936,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -4040,7 +4040,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -4139,8 +4139,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -4173,10 +4173,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -4191,7 +4191,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -4211,7 +4211,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -4317,7 +4317,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -4416,8 +4416,8 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -4450,10 +4450,10 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -4468,7 +4468,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -4488,7 +4488,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -4541,7 +4541,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -4616,8 +4616,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -4650,10 +4650,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -4668,7 +4668,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -4687,7 +4687,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -4744,7 +4744,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -4819,8 +4819,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -4853,10 +4853,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -4871,7 +4871,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
@@ -4890,7 +4890,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -4949,7 +4949,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
@@ -5024,8 +5024,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -5058,10 +5058,10 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -5076,7 +5076,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
|
||||
@@ -17,7 +17,7 @@ import json
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -113,7 +113,7 @@ class PeftAdapterMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -127,7 +127,7 @@ class PeftAdapterMixin:
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -447,16 +447,16 @@ class PeftAdapterMixin:
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
|
||||
adapter_names: list[str] | str,
|
||||
weights: Optional[float | Dict | list[float] | list[Dict] | list[None]] = None,
|
||||
):
|
||||
"""
|
||||
Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
adapter_names (`list[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
adapter_weights (`Union[List[float], float]`, *optional*):
|
||||
weights (`Union[list[float], float]`, *optional*):
|
||||
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
|
||||
adapters.
|
||||
|
||||
@@ -539,7 +539,7 @@ class PeftAdapterMixin:
|
||||
inject_adapter_in_model(adapter_config, self, adapter_name)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
|
||||
def set_adapter(self, adapter_name: str | list[str]) -> None:
|
||||
"""
|
||||
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
|
||||
|
||||
@@ -547,7 +547,7 @@ class PeftAdapterMixin:
|
||||
[documentation](https://huggingface.co/docs/peft).
|
||||
|
||||
Args:
|
||||
adapter_name (Union[str, List[str]])):
|
||||
adapter_name (`Union[str, list[str]]`):
|
||||
The list of adapters to set or the adapter name in the case of a single adapter.
|
||||
"""
|
||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||
@@ -633,7 +633,7 @@ class PeftAdapterMixin:
|
||||
# support for older PEFT versions
|
||||
module.disable_adapters = False
|
||||
|
||||
def active_adapters(self) -> List[str]:
|
||||
def active_adapters(self) -> list[str]:
|
||||
"""
|
||||
Gets the current list of active adapters of the model.
|
||||
|
||||
@@ -756,12 +756,12 @@ class PeftAdapterMixin:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
set_adapter_layers(self, enabled=True)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
def delete_adapters(self, adapter_names: list[str] | str):
|
||||
"""
|
||||
Delete an adapter's LoRA layers from the underlying model.
|
||||
|
||||
Args:
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
adapter_names (`Union[list[str], str]`):
|
||||
The names (single string or list of strings) of the adapter to delete.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -290,7 +290,7 @@ class FromSingleFileMixin:
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
|
||||
@@ -229,7 +229,7 @@ class FromOriginalModelMixin:
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -112,7 +112,7 @@ class TextualInversionLoaderMixin:
|
||||
Load Textual Inversion tokens and embeddings to the tokenizer and text encoder.
|
||||
"""
|
||||
|
||||
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
||||
def maybe_convert_prompt(self, prompt: str | list[str], tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
||||
r"""
|
||||
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
|
||||
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
||||
@@ -127,14 +127,14 @@ class TextualInversionLoaderMixin:
|
||||
Returns:
|
||||
`str` or list of `str`: The converted prompt
|
||||
"""
|
||||
if not isinstance(prompt, List):
|
||||
if not isinstance(prompt, list):
|
||||
prompts = [prompt]
|
||||
else:
|
||||
prompts = prompt
|
||||
|
||||
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
|
||||
|
||||
if not isinstance(prompt, List):
|
||||
if not isinstance(prompt, list):
|
||||
return prompts[0]
|
||||
|
||||
return prompts
|
||||
@@ -263,8 +263,8 @@ class TextualInversionLoaderMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_textual_inversion(
|
||||
self,
|
||||
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
|
||||
token: Optional[Union[str, List[str]]] = None,
|
||||
pretrained_model_name_or_path: str | list[str] | dict[str, torch.Tensor] | list[dict[str, torch.Tensor]],
|
||||
token: Optional[str | list[str]] = None,
|
||||
tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
**kwargs,
|
||||
@@ -274,7 +274,7 @@ class TextualInversionLoaderMixin:
|
||||
Automatic1111 formats are supported).
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike` or `list[str or os.PathLike]` or `Dict` or `list[Dict]`):
|
||||
Can be either one of the following or a list of them:
|
||||
|
||||
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
|
||||
@@ -285,7 +285,7 @@ class TextualInversionLoaderMixin:
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
token (`str` or `List[str]`, *optional*):
|
||||
token (`str` or `list[str]`, *optional*):
|
||||
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
|
||||
list, then `token` must also be a list of equal length.
|
||||
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
|
||||
@@ -306,7 +306,7 @@ class TextualInversionLoaderMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -458,7 +458,7 @@ class TextualInversionLoaderMixin:
|
||||
|
||||
def unload_textual_inversion(
|
||||
self,
|
||||
tokens: Optional[Union[str, List[str]]] = None,
|
||||
tokens: Optional[str | list[str]] = None,
|
||||
tokenizer: Optional["PreTrainedTokenizer"] = None,
|
||||
text_encoder: Optional["PreTrainedModel"] = None,
|
||||
):
|
||||
|
||||
@@ -15,7 +15,7 @@ import os
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Union
|
||||
from typing import Callable
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -66,7 +66,7 @@ class UNet2DConditionLoadersMixin:
|
||||
unet_name = UNET_NAME
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], **kwargs):
|
||||
r"""
|
||||
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
|
||||
defined in
|
||||
@@ -92,7 +92,7 @@ class UNet2DConditionLoadersMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -106,7 +106,7 @@ class UNet2DConditionLoadersMixin:
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -412,7 +412,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
def save_attn_procs(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_directory: str | os.PathLike,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Dict, List, Union
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from torch import nn
|
||||
|
||||
@@ -40,9 +40,7 @@ def _translate_into_actual_layer_name(name):
|
||||
return ".".join((updown, block, attn))
|
||||
|
||||
|
||||
def _maybe_expand_lora_scales(
|
||||
unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
|
||||
):
|
||||
def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: list[float | Dict], default_scale=1.0):
|
||||
blocks_with_transformer = {
|
||||
"down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
|
||||
"up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
|
||||
@@ -64,9 +62,9 @@ def _maybe_expand_lora_scales(
|
||||
|
||||
|
||||
def _maybe_expand_lora_scales_for_one_adapter(
|
||||
scales: Union[float, Dict],
|
||||
blocks_with_transformer: Dict[str, int],
|
||||
transformer_per_block: Dict[str, int],
|
||||
scales: float | Dict,
|
||||
blocks_with_transformer: dict[str, int],
|
||||
transformer_per_block: dict[str, int],
|
||||
model: nn.Module,
|
||||
default_scale: float = 1.0,
|
||||
):
|
||||
@@ -76,9 +74,9 @@ def _maybe_expand_lora_scales_for_one_adapter(
|
||||
Parameters:
|
||||
scales (`Union[float, Dict]`):
|
||||
Scales dict to expand.
|
||||
blocks_with_transformer (`Dict[str, int]`):
|
||||
blocks_with_transformer (`dict[str, int]`):
|
||||
Dict with keys 'up' and 'down', showing which blocks have transformer layers
|
||||
transformer_per_block (`Dict[str, int]`):
|
||||
transformer_per_block (`dict[str, int]`):
|
||||
Dict with keys 'up' and 'down', showing how many transformer layers each block has
|
||||
|
||||
E.g. turns
|
||||
|
||||
@@ -12,13 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AttnProcsLayers(torch.nn.Module):
|
||||
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
||||
def __init__(self, state_dict: dict[str, torch.Tensor]):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList(state_dict.values())
|
||||
self.mapping = dict(enumerate(state_dict.keys()))
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -187,19 +187,17 @@ class ContextParallelOutput:
|
||||
# If the key is a string, it denotes the name of the parameter in the forward function.
|
||||
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
|
||||
# to be split across context parallel region.
|
||||
ContextParallelInputType = Dict[
|
||||
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
|
||||
ContextParallelInputType = dict[
|
||||
str | int, ContextParallelInput | list[ContextParallelInput] | tuple[ContextParallelInput, ...]
|
||||
]
|
||||
|
||||
# A dictionary where keys denote the output to be gathered across context parallel region, and the
|
||||
# value denotes the gathering configuration.
|
||||
ContextParallelOutputType = Union[
|
||||
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
|
||||
]
|
||||
ContextParallelOutputType = ContextParallelOutput | list[ContextParallelOutput] | tuple[ContextParallelOutput, ...]
|
||||
|
||||
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
|
||||
# the module should be split/gathered across context parallel region.
|
||||
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
|
||||
ContextParallelModelPlan = dict[str, ContextParallelInputType | ContextParallelOutputType]
|
||||
|
||||
|
||||
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -34,11 +34,11 @@ class MultiAdapter(ModelMixin):
|
||||
or saving.
|
||||
|
||||
Args:
|
||||
adapters (`List[T2IAdapter]`, *optional*, defaults to None):
|
||||
adapters (`list[T2IAdapter]`, *optional*, defaults to None):
|
||||
A list of `T2IAdapter` model instances.
|
||||
"""
|
||||
|
||||
def __init__(self, adapters: List["T2IAdapter"]):
|
||||
def __init__(self, adapters: list["T2IAdapter"]):
|
||||
super(MultiAdapter, self).__init__()
|
||||
|
||||
self.num_adapter = len(adapters)
|
||||
@@ -73,7 +73,7 @@ class MultiAdapter(ModelMixin):
|
||||
self.total_downscale_factor = first_adapter_total_downscale_factor
|
||||
self.downscale_factor = first_adapter_downscale_factor
|
||||
|
||||
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
|
||||
def forward(self, xs: torch.Tensor, adapter_weights: Optional[list[float]] = None) -> list[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
xs (`torch.Tensor`):
|
||||
@@ -81,7 +81,7 @@ class MultiAdapter(ModelMixin):
|
||||
models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to
|
||||
`num_adapter` * number of channel per image.
|
||||
|
||||
adapter_weights (`List[float]`, *optional*, defaults to None):
|
||||
adapter_weights (`list[float]`, *optional*, defaults to None):
|
||||
A list of floats representing the weights which will be multiplied by each adapter's output before
|
||||
summing them together. If `None`, equal weights will be used for all adapters.
|
||||
"""
|
||||
@@ -104,7 +104,7 @@ class MultiAdapter(ModelMixin):
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_directory: str | os.PathLike,
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
@@ -145,7 +145,7 @@ class MultiAdapter(ModelMixin):
|
||||
model_path_to_save = model_path_to_save + f"_{idx}"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[str | os.PathLike], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
|
||||
|
||||
@@ -165,7 +165,7 @@ class MultiAdapter(ModelMixin):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
@@ -229,7 +229,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
in_channels (`int`, *optional*, defaults to `3`):
|
||||
The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
|
||||
image.
|
||||
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
channels (`list[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The number of channels in each downsample block's output hidden state. The `len(block_out_channels)`
|
||||
determines the number of downsample blocks in the adapter.
|
||||
num_res_blocks (`int`, *optional*, defaults to `2`):
|
||||
@@ -244,7 +244,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280, 1280],
|
||||
channels: list[int] = [320, 640, 1280, 1280],
|
||||
num_res_blocks: int = 2,
|
||||
downscale_factor: int = 8,
|
||||
adapter_type: str = "full_adapter",
|
||||
@@ -263,7 +263,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
"'full_adapter_xl' or 'light_adapter'."
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
r"""
|
||||
This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
|
||||
each representing information extracted at a different scale from the input. The length of the list is
|
||||
@@ -295,7 +295,7 @@ class FullAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280, 1280],
|
||||
channels: list[int] = [320, 640, 1280, 1280],
|
||||
num_res_blocks: int = 2,
|
||||
downscale_factor: int = 8,
|
||||
):
|
||||
@@ -318,7 +318,7 @@ class FullAdapter(nn.Module):
|
||||
|
||||
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
r"""
|
||||
This method processes the input tensor `x` through the FullAdapter model and performs operations including
|
||||
pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
|
||||
@@ -345,7 +345,7 @@ class FullAdapterXL(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280, 1280],
|
||||
channels: list[int] = [320, 640, 1280, 1280],
|
||||
num_res_blocks: int = 2,
|
||||
downscale_factor: int = 16,
|
||||
):
|
||||
@@ -370,7 +370,7 @@ class FullAdapterXL(nn.Module):
|
||||
# XL has only one downsampling AdapterBlock.
|
||||
self.total_downscale_factor = downscale_factor * 2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
r"""
|
||||
This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations
|
||||
including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors.
|
||||
@@ -473,7 +473,7 @@ class LightAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280],
|
||||
channels: list[int] = [320, 640, 1280],
|
||||
num_res_blocks: int = 4,
|
||||
downscale_factor: int = 8,
|
||||
):
|
||||
@@ -496,7 +496,7 @@ class LightAdapter(nn.Module):
|
||||
|
||||
self.total_downscale_factor = downscale_factor * (2 ** len(channels))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
r"""
|
||||
This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
|
||||
feature tensor corresponds to a different level of processing within the LightAdapter.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
class AttentionMixin:
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -47,7 +47,7 @@ class AttentionMixin:
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -61,7 +61,7 @@ class AttentionMixin:
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -184,7 +184,7 @@ class AttentionModuleMixin:
|
||||
def set_use_xla_flash_attention(
|
||||
self,
|
||||
use_xla_flash_attention: bool,
|
||||
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
|
||||
partition_spec: Optional[tuple[Optional[str], ...]] = None,
|
||||
is_flux=False,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -193,7 +193,7 @@ class AttentionModuleMixin:
|
||||
Args:
|
||||
use_xla_flash_attention (`bool`):
|
||||
Whether to use pallas flash attention kernel from `torch_xla` or not.
|
||||
partition_spec (`Tuple[]`, *optional*):
|
||||
partition_spec (`tuple[]`, *optional*):
|
||||
Specify the partition specification if using SPMD. Otherwise None.
|
||||
is_flux (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model is a Flux model.
|
||||
@@ -669,8 +669,8 @@ class JointTransformerBlock(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
@@ -950,9 +950,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
cross_attention_kwargs: dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
@@ -1487,7 +1487,7 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
|
||||
def _get_frame_indices(self, num_frames: int) -> list[tuple[int, int]]:
|
||||
frame_indices = []
|
||||
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
|
||||
window_start = i
|
||||
@@ -1495,7 +1495,7 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
frame_indices.append((window_start, window_end))
|
||||
return frame_indices
|
||||
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> list[float]:
|
||||
if weighting_scheme == "flat":
|
||||
weights = [1.0] * num_frames
|
||||
|
||||
@@ -1545,7 +1545,7 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
cross_attention_kwargs: dict[str, Any] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -12,12 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -228,7 +230,7 @@ class _AttentionBackendRegistry:
|
||||
def register(
|
||||
cls,
|
||||
backend: AttentionBackendName,
|
||||
constraints: Optional[List[Callable]] = None,
|
||||
constraints: Optional[list[Callable]] = None,
|
||||
supports_context_parallel: bool = False,
|
||||
):
|
||||
logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
|
||||
@@ -263,7 +265,7 @@ class _AttentionBackendRegistry:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
|
||||
def attention_backend(backend: str | AttentionBackendName = AttentionBackendName.NATIVE):
|
||||
"""
|
||||
Context manager to set the active attention backend.
|
||||
"""
|
||||
@@ -291,7 +293,7 @@ def dispatch_attention_fn(
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
backend: Optional[AttentionBackendName] = None,
|
||||
parallel_config: Optional["ParallelConfig"] = None,
|
||||
@@ -595,7 +597,7 @@ def _wrapped_flash_attn_3(
|
||||
pack_gqa: Optional[bool] = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Hardcoded for now because pytorch does not support tuple/int type hints
|
||||
window_size = (-1, -1)
|
||||
out, lse, *_ = flash_attn_3_func(
|
||||
@@ -637,7 +639,7 @@ def _(
|
||||
pack_gqa: Optional[bool] = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
window_size = (-1, -1) # noqa: F841
|
||||
# A lot of the parameters here are not yet used in any way within diffusers.
|
||||
# We can safely ignore for now and keep the fake op shape propagation simple.
|
||||
@@ -1335,7 +1337,7 @@ def _flash_attention_3_hub(
|
||||
value: torch.Tensor,
|
||||
scale: Optional[float] = None,
|
||||
is_causal: bool = False,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
deterministic: bool = False,
|
||||
return_attn_probs: bool = False,
|
||||
@@ -1465,7 +1467,7 @@ def _native_flex_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
|
||||
attn_mask: Optional[torch.Tensor | "flex_attention.BlockMask"] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import math
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -309,7 +309,7 @@ class Attention(nn.Module):
|
||||
def set_use_xla_flash_attention(
|
||||
self,
|
||||
use_xla_flash_attention: bool,
|
||||
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
|
||||
partition_spec: Optional[tuple[Optional[str], ...]] = None,
|
||||
is_flux=False,
|
||||
) -> None:
|
||||
r"""
|
||||
@@ -318,7 +318,7 @@ class Attention(nn.Module):
|
||||
Args:
|
||||
use_xla_flash_attention (`bool`):
|
||||
Whether to use pallas flash attention kernel from `torch_xla` or not.
|
||||
partition_spec (`Tuple[]`, *optional*):
|
||||
partition_spec (`tuple[]`, *optional*):
|
||||
Specify the partition specification if using SPMD. Otherwise None.
|
||||
"""
|
||||
if use_xla_flash_attention:
|
||||
@@ -872,7 +872,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
|
||||
attention_head_dim: int = 8,
|
||||
mult: float = 1.0,
|
||||
norm_type: str = "batch_norm",
|
||||
kernel_sizes: Tuple[int, ...] = (5,),
|
||||
kernel_sizes: tuple[int, ...] = (5,),
|
||||
eps: float = 1e-15,
|
||||
residual_connection: bool = False,
|
||||
):
|
||||
@@ -2790,7 +2790,7 @@ class XLAFlashAttnProcessor2_0:
|
||||
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
|
||||
"""
|
||||
|
||||
def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
|
||||
def __init__(self, partition_spec: Optional[tuple[Optional[str], ...]] = None):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
@@ -3001,7 +3001,7 @@ class StableAudioAttnProcessor2_0:
|
||||
def apply_partial_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Tuple[torch.Tensor],
|
||||
freqs_cis: tuple[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
@@ -4212,9 +4212,9 @@ class IPAdapterAttnProcessor(nn.Module):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
num_tokens (`int`, `tuple[int]` or `list[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float` or List[`float`], defaults to 1.0):
|
||||
scale (`float` or list[`float`], defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
@@ -4305,7 +4305,7 @@ class IPAdapterAttnProcessor(nn.Module):
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if ip_adapter_masks is not None:
|
||||
if not isinstance(ip_adapter_masks, List):
|
||||
if not isinstance(ip_adapter_masks, list):
|
||||
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
||||
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
||||
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
||||
@@ -4412,9 +4412,9 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
num_tokens (`int`, `tuple[int]` or `list[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float` or `List[float]`, defaults to 1.0):
|
||||
scale (`float` or `list[float]`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
@@ -4524,7 +4524,7 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if ip_adapter_masks is not None:
|
||||
if not isinstance(ip_adapter_masks, List):
|
||||
if not isinstance(ip_adapter_masks, list):
|
||||
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
||||
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
||||
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
||||
@@ -4644,9 +4644,9 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
num_tokens (`int`, `tuple[int]` or `list[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float` or `List[float]`, defaults to 1.0):
|
||||
scale (`float` or `list[float]`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
attention_op (`Callable`, *optional*, defaults to `None`):
|
||||
The base
|
||||
@@ -4763,7 +4763,7 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
||||
|
||||
if ip_hidden_states:
|
||||
if ip_adapter_masks is not None:
|
||||
if not isinstance(ip_adapter_masks, List):
|
||||
if not isinstance(ip_adapter_masks, list):
|
||||
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
||||
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
||||
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
||||
@@ -5622,56 +5622,56 @@ CROSS_ATTENTION_PROCESSORS = (
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
AttentionProcessor = Union[
|
||||
AttnProcessor,
|
||||
CustomDiffusionAttnProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
JointAttnProcessor2_0,
|
||||
PAGJointAttnProcessor2_0,
|
||||
PAGCFGJointAttnProcessor2_0,
|
||||
FusedJointAttnProcessor2_0,
|
||||
AllegroAttnProcessor2_0,
|
||||
AuraFlowAttnProcessor2_0,
|
||||
FusedAuraFlowAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0_NPU,
|
||||
FusedFluxAttnProcessor2_0,
|
||||
FusedFluxAttnProcessor2_0_NPU,
|
||||
CogVideoXAttnProcessor2_0,
|
||||
FusedCogVideoXAttnProcessor2_0,
|
||||
XFormersAttnAddedKVProcessor,
|
||||
XFormersAttnProcessor,
|
||||
XLAFlashAttnProcessor2_0,
|
||||
AttnProcessorNPU,
|
||||
AttnProcessor2_0,
|
||||
MochiVaeAttnProcessor2_0,
|
||||
MochiAttnProcessor2_0,
|
||||
StableAudioAttnProcessor2_0,
|
||||
HunyuanAttnProcessor2_0,
|
||||
FusedHunyuanAttnProcessor2_0,
|
||||
PAGHunyuanAttnProcessor2_0,
|
||||
PAGCFGHunyuanAttnProcessor2_0,
|
||||
LuminaAttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
SlicedAttnProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
SanaLinearAttnProcessor2_0,
|
||||
PAGCFGSanaLinearAttnProcessor2_0,
|
||||
PAGIdentitySanaLinearAttnProcessor2_0,
|
||||
SanaMultiscaleLinearAttention,
|
||||
SanaMultiscaleAttnProcessor2_0,
|
||||
SanaMultiscaleAttentionProjection,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
SD3IPAdapterJointAttnProcessor2_0,
|
||||
PAGIdentitySelfAttnProcessor2_0,
|
||||
PAGCFGIdentitySelfAttnProcessor2_0,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
]
|
||||
AttentionProcessor = (
|
||||
AttnProcessor
|
||||
| CustomDiffusionAttnProcessor
|
||||
| AttnAddedKVProcessor
|
||||
| AttnAddedKVProcessor2_0
|
||||
| JointAttnProcessor2_0
|
||||
| PAGJointAttnProcessor2_0
|
||||
| PAGCFGJointAttnProcessor2_0
|
||||
| FusedJointAttnProcessor2_0
|
||||
| AllegroAttnProcessor2_0
|
||||
| AuraFlowAttnProcessor2_0
|
||||
| FusedAuraFlowAttnProcessor2_0
|
||||
| FluxAttnProcessor2_0
|
||||
| FluxAttnProcessor2_0_NPU
|
||||
| FusedFluxAttnProcessor2_0
|
||||
| FusedFluxAttnProcessor2_0_NPU
|
||||
| CogVideoXAttnProcessor2_0
|
||||
| FusedCogVideoXAttnProcessor2_0
|
||||
| XFormersAttnAddedKVProcessor
|
||||
| XFormersAttnProcessor
|
||||
| XLAFlashAttnProcessor2_0
|
||||
| AttnProcessorNPU
|
||||
| AttnProcessor2_0
|
||||
| MochiVaeAttnProcessor2_0
|
||||
| MochiAttnProcessor2_0
|
||||
| StableAudioAttnProcessor2_0
|
||||
| HunyuanAttnProcessor2_0
|
||||
| FusedHunyuanAttnProcessor2_0
|
||||
| PAGHunyuanAttnProcessor2_0
|
||||
| PAGCFGHunyuanAttnProcessor2_0
|
||||
| LuminaAttnProcessor2_0
|
||||
| FusedAttnProcessor2_0
|
||||
| CustomDiffusionXFormersAttnProcessor
|
||||
| CustomDiffusionAttnProcessor2_0
|
||||
| SlicedAttnProcessor
|
||||
| SlicedAttnAddedKVProcessor
|
||||
| SanaLinearAttnProcessor2_0
|
||||
| PAGCFGSanaLinearAttnProcessor2_0
|
||||
| PAGIdentitySanaLinearAttnProcessor2_0
|
||||
| SanaMultiscaleLinearAttention
|
||||
| SanaMultiscaleAttnProcessor2_0
|
||||
| SanaMultiscaleAttentionProjection
|
||||
| IPAdapterAttnProcessor
|
||||
| IPAdapterAttnProcessor2_0
|
||||
| IPAdapterXFormersAttnProcessor
|
||||
| SD3IPAdapterJointAttnProcessor2_0
|
||||
| PAGIdentitySelfAttnProcessor2_0
|
||||
| PAGCFGIdentitySelfAttnProcessor2_0
|
||||
| LoRAAttnProcessor
|
||||
| LoRAAttnProcessor2_0
|
||||
| LoRAXFormersAttnProcessor
|
||||
| LoRAAttnAddedKVProcessor
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
@@ -37,7 +37,7 @@ class AutoModel(ConfigMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_or_path: Optional[str | os.PathLike] = None, **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
||||
|
||||
@@ -61,7 +61,7 @@ class AutoModel(ConfigMixin):
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info (`bool`, *optional*, defaults to `False`):
|
||||
@@ -83,7 +83,7 @@ class AutoModel(ConfigMixin):
|
||||
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -34,16 +34,16 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of down block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
tuple of downsample block types.
|
||||
down_block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of down block output channels.
|
||||
layers_per_down_block (`int`, *optional*, defaults to `1`):
|
||||
Number layers for down block.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of up block output channels.
|
||||
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
tuple of upsample block types.
|
||||
up_block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of up block output channels.
|
||||
layers_per_up_block (`int`, *optional*, defaults to `1`):
|
||||
Number layers for up block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
@@ -67,11 +67,11 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
down_block_out_channels: Tuple[int, ...] = (64,),
|
||||
down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
down_block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_down_block: int = 1,
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
up_block_out_channels: Tuple[int, ...] = (64,),
|
||||
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
up_block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_up_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
@@ -111,7 +111,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self.register_to_config(force_upcast=False)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput | tuple[torch.Tensor]:
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
@@ -127,7 +127,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
image: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z, image, mask)
|
||||
|
||||
@@ -144,7 +144,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
image: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
decoded = self._decode(z, image, mask).sample
|
||||
|
||||
if not return_dict:
|
||||
@@ -159,7 +159,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -68,7 +68,7 @@ class EfficientViTBlock(nn.Module):
|
||||
in_channels: int,
|
||||
mult: float = 1.0,
|
||||
attention_head_dim: int = 32,
|
||||
qkv_multiscales: Tuple[int, ...] = (5,),
|
||||
qkv_multiscales: tuple[int, ...] = (5,),
|
||||
norm_type: str = "batch_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -102,7 +102,7 @@ def get_block(
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: Tuple[int] = (),
|
||||
qkv_mutliscales: tuple[int] = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
@@ -205,10 +205,10 @@ class Encoder(nn.Module):
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
block_type: str | tuple[str] = "ResBlock",
|
||||
block_out_channels: tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
):
|
||||
@@ -291,12 +291,12 @@ class Decoder(nn.Module):
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: Union[str, Tuple[str]] = "rms_norm",
|
||||
act_fn: Union[str, Tuple[str]] = "silu",
|
||||
block_type: str | tuple[str] = "ResBlock",
|
||||
block_out_channels: tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: str | tuple[str] = "rms_norm",
|
||||
act_fn: str | tuple[str] = "silu",
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
in_shortcut: bool = True,
|
||||
conv_act_fn: str = "relu",
|
||||
@@ -391,29 +391,29 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
The number of input channels in samples.
|
||||
latent_channels (`int`, defaults to `32`):
|
||||
The number of channels in the latent space representation.
|
||||
encoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
|
||||
encoder_block_types (`Union[str, tuple[str]]`, defaults to `"ResBlock"`):
|
||||
The type(s) of block to use in the encoder.
|
||||
decoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
|
||||
decoder_block_types (`Union[str, tuple[str]]`, defaults to `"ResBlock"`):
|
||||
The type(s) of block to use in the decoder.
|
||||
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
encoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
The number of output channels for each block in the encoder.
|
||||
decoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
decoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
The number of output channels for each block in the decoder.
|
||||
encoder_layers_per_block (`Tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
|
||||
encoder_layers_per_block (`tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
|
||||
The number of layers per block in the encoder.
|
||||
decoder_layers_per_block (`Tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
|
||||
decoder_layers_per_block (`tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
|
||||
The number of layers per block in the decoder.
|
||||
encoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
encoder_qkv_multiscales (`tuple[tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
Multi-scale configurations for the encoder's QKV (query-key-value) transformations.
|
||||
decoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
decoder_qkv_multiscales (`tuple[tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
Multi-scale configurations for the decoder's QKV (query-key-value) transformations.
|
||||
upsample_block_type (`str`, defaults to `"pixel_shuffle"`):
|
||||
The type of block to use for upsampling in the decoder.
|
||||
downsample_block_type (`str`, defaults to `"pixel_unshuffle"`):
|
||||
The type of block to use for downsampling in the encoder.
|
||||
decoder_norm_types (`Union[str, Tuple[str]]`, defaults to `"rms_norm"`):
|
||||
decoder_norm_types (`Union[str, tuple[str]]`, defaults to `"rms_norm"`):
|
||||
The normalization type(s) to use in the decoder.
|
||||
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
|
||||
decoder_act_fns (`Union[str, tuple[str]]`, defaults to `"silu"`):
|
||||
The activation function(s) to use in the decoder.
|
||||
encoder_out_shortcut (`bool`, defaults to `True`):
|
||||
Whether to use shortcut at the end of the encoder.
|
||||
@@ -436,18 +436,18 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
in_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
attention_head_dim: int = 32,
|
||||
encoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
encoder_block_types: str | tuple[str] = "ResBlock",
|
||||
decoder_block_types: str | tuple[str] = "ResBlock",
|
||||
encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
decoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
encoder_layers_per_block: tuple[int] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: tuple[int] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
decoder_qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
|
||||
decoder_act_fns: Union[str, Tuple[str]] = "silu",
|
||||
decoder_norm_types: str | tuple[str] = "rms_norm",
|
||||
decoder_act_fns: str | tuple[str] = "silu",
|
||||
encoder_out_shortcut: bool = True,
|
||||
decoder_in_shortcut: bool = True,
|
||||
decoder_conv_act_fn: str = "relu",
|
||||
@@ -547,7 +547,7 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
return encoded
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]:
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> EncoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -581,7 +581,7 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
return decoded
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -665,7 +665,7 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
return (encoded,)
|
||||
return EncoderOutput(latent=encoded)
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, height, width = z.shape
|
||||
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -45,12 +45,12 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
tuple of downsample block types.
|
||||
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
tuple of upsample block types.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
down_block_types: tuple[str] = ("DownEncoderBlock2D",),
|
||||
up_block_types: tuple[str] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: tuple[int] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
@@ -88,8 +88,8 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
sample_size: int = 32,
|
||||
scaling_factor: float = 0.18215,
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
latents_std: Optional[Tuple[float]] = None,
|
||||
latents_mean: Optional[tuple[float]] = None,
|
||||
latents_std: Optional[tuple[float]] = None,
|
||||
force_upcast: bool = True,
|
||||
use_quant_conv: bool = True,
|
||||
use_post_quant_conv: bool = True,
|
||||
@@ -140,7 +140,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -149,7 +149,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -164,7 +164,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -229,7 +229,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -255,7 +255,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
@@ -272,7 +272,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
) -> DecoderOutput | torch.FloatTensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -420,7 +420,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -475,7 +475,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -417,14 +417,14 @@ class AllegroEncoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False],
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_downsample_blocks: tuple[bool, ...] = [True, True, False, False],
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -544,14 +544,14 @@ class AllegroDecoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
),
|
||||
temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False],
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_upsample_blocks: tuple[bool, ...] = [False, True, True, False],
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -687,14 +687,14 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Number of channels in the input image.
|
||||
out_channels (int, defaults to `3`):
|
||||
Number of channels in the output.
|
||||
down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
|
||||
Tuple of strings denoting which types of down blocks to use.
|
||||
up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
|
||||
Tuple of strings denoting which types of up blocks to use.
|
||||
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
Tuple of integers denoting number of output channels in each block.
|
||||
temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`):
|
||||
Tuple of booleans denoting which blocks to enable temporal downsampling in.
|
||||
down_block_types (`tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
|
||||
tuple of strings denoting which types of down blocks to use.
|
||||
up_block_types (`tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
|
||||
tuple of strings denoting which types of up blocks to use.
|
||||
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
tuple of integers denoting number of output channels in each block.
|
||||
temporal_downsample_blocks (`tuple[bool, ...]`, defaults to `(True, True, False, False)`):
|
||||
tuple of booleans denoting which blocks to enable temporal downsampling in.
|
||||
latent_channels (`int`, defaults to `4`):
|
||||
Number of channels in latents.
|
||||
layers_per_block (`int`, defaults to `2`):
|
||||
@@ -727,21 +727,21 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
|
||||
temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_downsample_blocks: tuple[bool, ...] = (True, True, False, False),
|
||||
temporal_upsample_blocks: tuple[bool, ...] = (False, True, True, False),
|
||||
latent_channels: int = 4,
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
@@ -807,7 +807,7 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
r"""
|
||||
Encode a batch of videos into latents.
|
||||
|
||||
@@ -842,7 +842,7 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
raise NotImplementedError("Decoding without tiling has not been implemented yet.")
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of videos.
|
||||
|
||||
@@ -1045,7 +1045,7 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -72,7 +72,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input tensor.
|
||||
out_channels (`int`): Number of output channels produced by the convolution.
|
||||
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
|
||||
kernel_size (`int` or `tuple[int, int, int]`): Kernel size of the convolutional kernel.
|
||||
stride (`int`, defaults to `1`): Stride of the convolution.
|
||||
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
|
||||
pad_mode (`str`, defaults to `"constant"`): Padding mode.
|
||||
@@ -82,7 +82,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
kernel_size: int | tuple[int, int, int],
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
pad_mode: str = "constant",
|
||||
@@ -174,7 +174,7 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
||||
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(
|
||||
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
||||
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[dict[str, torch.Tensor]] = None
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
@@ -289,7 +289,7 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
inputs: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
@@ -411,7 +411,7 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXDownBlock3D` class."""
|
||||
|
||||
@@ -506,7 +506,7 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXMidBlock3D` class."""
|
||||
|
||||
@@ -613,7 +613,7 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXUpBlock3D` class."""
|
||||
|
||||
@@ -652,10 +652,10 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
|
||||
options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
@@ -671,13 +671,13 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 16,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 256, 512),
|
||||
layers_per_block: int = 3,
|
||||
act_fn: str = "silu",
|
||||
norm_eps: float = 1e-6,
|
||||
@@ -744,7 +744,7 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXEncoder3D` class."""
|
||||
|
||||
@@ -805,9 +805,9 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
@@ -823,13 +823,13 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 256, 512),
|
||||
layers_per_block: int = 3,
|
||||
act_fn: str = "silu",
|
||||
norm_eps: float = 1e-6,
|
||||
@@ -903,7 +903,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXDecoder3D` class."""
|
||||
|
||||
@@ -966,12 +966,12 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
||||
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: tuple[str] = (
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str] = (
|
||||
up_block_types: tuple[str] = (
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (128, 256, 256, 512),
|
||||
block_out_channels: tuple[int] = (128, 256, 256, 512),
|
||||
latent_channels: int = 16,
|
||||
layers_per_block: int = 3,
|
||||
act_fn: str = "silu",
|
||||
@@ -1018,8 +1018,8 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
sample_width: int = 720,
|
||||
scaling_factor: float = 1.15258426,
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
latents_std: Optional[Tuple[float]] = None,
|
||||
latents_mean: Optional[tuple[float]] = None,
|
||||
latents_std: Optional[tuple[float]] = None,
|
||||
force_upcast: float = True,
|
||||
use_quant_conv: bool = False,
|
||||
use_post_quant_conv: bool = False,
|
||||
@@ -1153,7 +1153,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -1178,7 +1178,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
||||
@@ -1207,7 +1207,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1321,7 +1321,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
enc = torch.cat(result_rows, dim=3)
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1410,7 +1410,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor | torch.Tensor:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -47,9 +49,9 @@ class CosmosCausalConv3d(nn.Conv3d):
|
||||
self,
|
||||
in_channels: int = 1,
|
||||
out_channels: int = 1,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = (3, 3, 3),
|
||||
dilation: Union[int, Tuple[int, int, int]] = (1, 1, 1),
|
||||
stride: Union[int, Tuple[int, int, int]] = (1, 1, 1),
|
||||
kernel_size: int | tuple[int, int, int] = (3, 3, 3),
|
||||
dilation: int | tuple[int, int, int] = (1, 1, 1),
|
||||
stride: int | tuple[int, int, int] = (1, 1, 1),
|
||||
padding: int = 1,
|
||||
pad_mode: str = "constant",
|
||||
) -> None:
|
||||
@@ -419,7 +421,7 @@ class CosmosCausalAttention(nn.Module):
|
||||
attention_head_dim: int,
|
||||
num_groups: int = 1,
|
||||
dropout: float = 0.0,
|
||||
processor: Union["CosmosSpatialAttentionProcessor2_0", "CosmosTemporalAttentionProcessor2_0"] = None,
|
||||
processor: "CosmosSpatialAttentionProcessor2_0" | "CosmosTemporalAttentionProcessor2_0" = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -711,9 +713,9 @@ class CosmosEncoder3d(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 16,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
num_resnet_blocks: int = 2,
|
||||
attention_resolutions: Tuple[int, ...] = (32,),
|
||||
attention_resolutions: tuple[int, ...] = (32,),
|
||||
resolution: int = 1024,
|
||||
patch_size: int = 4,
|
||||
patch_type: str = "haar",
|
||||
@@ -795,9 +797,9 @@ class CosmosDecoder3d(nn.Module):
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
num_resnet_blocks: int = 2,
|
||||
attention_resolutions: Tuple[int, ...] = (32,),
|
||||
attention_resolutions: tuple[int, ...] = (32,),
|
||||
resolution: int = 1024,
|
||||
patch_size: int = 4,
|
||||
patch_type: str = "haar",
|
||||
@@ -886,12 +888,12 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Number of output channels.
|
||||
latent_channels (`int`, defaults to `16`):
|
||||
Number of latent channels.
|
||||
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
encoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
Number of output channels for each encoder down block.
|
||||
decode_block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
|
||||
decode_block_out_channels (`tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
|
||||
Number of output channels for each decoder up block.
|
||||
attention_resolutions (`Tuple[int, ...]`, defaults to `(32,)`):
|
||||
List of image/video resolutions at which to apply attention.
|
||||
attention_resolutions (`tuple[int, ...]`, defaults to `(32,)`):
|
||||
List of image/video resolutions at which to apply attention.
|
||||
resolution (`int`, defaults to `1024`):
|
||||
Base image/video resolution used for computing whether a block should have attention layers.
|
||||
num_layers (`int`, defaults to `2`):
|
||||
@@ -924,9 +926,9 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 16,
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
decode_block_out_channels: Tuple[int, ...] = (256, 512, 512, 512),
|
||||
attention_resolutions: Tuple[int, ...] = (32,),
|
||||
encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
decode_block_out_channels: tuple[int, ...] = (256, 512, 512, 512),
|
||||
attention_resolutions: tuple[int, ...] = (32,),
|
||||
resolution: int = 1024,
|
||||
num_layers: int = 2,
|
||||
patch_size: int = 4,
|
||||
@@ -934,8 +936,8 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
scaling_factor: float = 1.0,
|
||||
spatial_compression_ratio: int = 8,
|
||||
temporal_compression_ratio: int = 8,
|
||||
latents_mean: Optional[List[float]] = LATENTS_MEAN,
|
||||
latents_std: Optional[List[float]] = LATENTS_STD,
|
||||
latents_mean: Optional[list[float]] = LATENTS_MEAN,
|
||||
latents_std: Optional[list[float]] = LATENTS_STD,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1050,7 +1052,7 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
|
||||
@@ -1059,7 +1061,7 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
@@ -1076,7 +1078,7 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[Tuple[torch.Tensor], DecoderOutput]:
|
||||
) -> tuple[torch.Tensor] | DecoderOutput:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -50,10 +50,10 @@ class HunyuanVideoCausalConv3d(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
kernel_size: int | tuple[int, int, int] = 3,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
padding: int | tuple[int, int, int] = 0,
|
||||
dilation: int | tuple[int, int, int] = 1,
|
||||
bias: bool = True,
|
||||
pad_mode: str = "replicate",
|
||||
) -> None:
|
||||
@@ -86,7 +86,7 @@ class HunyuanVideoUpsampleCausal3D(nn.Module):
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
bias: bool = True,
|
||||
upsample_factor: Tuple[float, float, float] = (2, 2, 2),
|
||||
upsample_factor: tuple[float, float, float] = (2, 2, 2),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -357,7 +357,7 @@ class HunyuanVideoUpBlock3D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
add_upsample: bool = True,
|
||||
upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
|
||||
upsample_scale_factor: tuple[int, int, int] = (2, 2, 2),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -418,13 +418,13 @@ class HunyuanVideoEncoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -526,13 +526,13 @@ class HunyuanVideoDecoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -641,19 +641,19 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 16,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: int = 32,
|
||||
@@ -779,7 +779,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -804,7 +804,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
@@ -825,7 +825,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -924,7 +924,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1013,7 +1013,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
|
||||
return enc
|
||||
|
||||
def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
|
||||
|
||||
@@ -1055,7 +1055,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -34,9 +34,9 @@ class LTXVideoCausalConv3d(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
kernel_size: int | tuple[int, int, int] = 3,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
dilation: int | tuple[int, int, int] = 1,
|
||||
groups: int = 1,
|
||||
padding_mode: str = "zeros",
|
||||
is_causal: bool = True,
|
||||
@@ -201,7 +201,7 @@ class LTXVideoDownsampler3d(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
is_causal: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
) -> None:
|
||||
@@ -249,7 +249,7 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
is_causal: bool = True,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
@@ -735,11 +735,11 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
Number of input channels.
|
||||
out_channels (`int`, defaults to 128):
|
||||
Number of latent channels.
|
||||
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
The number of output channels for each block.
|
||||
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
|
||||
spatio_temporal_scaling (`tuple[bool, ...]`, defaults to `(True, True, True, False)`):
|
||||
Whether a block should contain spatio-temporal downscaling layers or not.
|
||||
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
The number of layers per block.
|
||||
patch_size (`int`, defaults to `4`):
|
||||
The size of spatial patches.
|
||||
@@ -755,16 +755,16 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
downsample_type: tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
@@ -888,11 +888,11 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
Number of latent channels.
|
||||
out_channels (`int`, defaults to 3):
|
||||
Number of output channels.
|
||||
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
The number of output channels for each block.
|
||||
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
|
||||
spatio_temporal_scaling (`tuple[bool, ...]`, defaults to `(True, True, True, False)`):
|
||||
Whether a block should contain spatio-temporal upscaling layers or not.
|
||||
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
The number of layers per block.
|
||||
patch_size (`int`, defaults to `4`):
|
||||
The size of spatial patches.
|
||||
@@ -910,17 +910,17 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
is_causal: bool = False,
|
||||
inject_noise: Tuple[bool, ...] = (False, False, False, False),
|
||||
inject_noise: tuple[bool, ...] = (False, False, False, False),
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
|
||||
upsample_residual: tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: tuple[bool, ...] = (1, 1, 1, 1),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1049,11 +1049,11 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
Number of output channels.
|
||||
latent_channels (`int`, defaults to `128`):
|
||||
Number of latent channels.
|
||||
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
The number of output channels for each block.
|
||||
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
|
||||
spatio_temporal_scaling (`tuple[bool, ...]`, defaults to `(True, True, True, False)`):
|
||||
Whether a block should contain spatio-temporal downscaling or not.
|
||||
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
The number of layers per block.
|
||||
patch_size (`int`, defaults to `4`):
|
||||
The size of spatial patches.
|
||||
@@ -1082,22 +1082,22 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
|
||||
decoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False, False),
|
||||
downsample_type: tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
upsample_residual: tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: tuple[int, ...] = (1, 1, 1, 1),
|
||||
timestep_conditioning: bool = False,
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
@@ -1235,7 +1235,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -1261,7 +1261,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
|
||||
def _decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
@@ -1283,7 +1283,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1390,7 +1390,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
|
||||
def tiled_decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1480,7 +1480,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
|
||||
def _temporal_tiled_decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
|
||||
|
||||
@@ -1523,7 +1523,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor | torch.Tensor:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -37,10 +37,10 @@ class EasyAnimateCausalConv3d(nn.Conv3d):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, ...]] = 3,
|
||||
stride: Union[int, Tuple[int, ...]] = 1,
|
||||
padding: Union[int, Tuple[int, ...]] = 1,
|
||||
dilation: Union[int, Tuple[int, ...]] = 1,
|
||||
kernel_size: int | tuple[int, ...] = 3,
|
||||
stride: int | tuple[int, ...] = 1,
|
||||
padding: int | tuple[int, ...] = 1,
|
||||
dilation: int | tuple[int, ...] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
@@ -437,13 +437,13 @@ class EasyAnimateEncoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 8,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"SpatialDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
|
||||
block_out_channels: tuple[int, ...] = [128, 256, 512, 512],
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -553,13 +553,13 @@ class EasyAnimateDecoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 8,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"SpatialUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
|
||||
block_out_channels: tuple[int, ...] = [128, 256, 512, 512],
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -680,14 +680,14 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
in_channels: int = 3,
|
||||
latent_channels: int = 16,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
|
||||
down_block_types: Tuple[str, ...] = [
|
||||
block_out_channels: tuple[int, ...] = [128, 256, 512, 512],
|
||||
down_block_types: tuple[str, ...] = [
|
||||
"SpatialDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
],
|
||||
up_block_types: Tuple[str, ...] = [
|
||||
up_block_types: tuple[str, ...] = [
|
||||
"SpatialUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
@@ -808,7 +808,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def _encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -838,7 +838,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -863,7 +863,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
@@ -890,7 +890,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -983,7 +983,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
moments = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return moments
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
sample_height = height * self.spatial_compression_ratio
|
||||
sample_width = width * self.spatial_compression_ratio
|
||||
@@ -1050,7 +1050,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -106,7 +106,7 @@ class MochiResnetBlock3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
@@ -193,7 +193,7 @@ class MochiDownBlock3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
chunk_size: int = 2**15,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiDownBlock3D` class."""
|
||||
@@ -294,7 +294,7 @@ class MochiMidBlock3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiMidBlock3D` class."""
|
||||
|
||||
@@ -368,7 +368,7 @@ class MochiUpBlock3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiUpBlock3D` class."""
|
||||
|
||||
@@ -445,13 +445,13 @@ class MochiEncoder3D(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
The number of output channels.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
|
||||
layers_per_block (`tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
|
||||
The number of resnet blocks for each block.
|
||||
temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
|
||||
temporal_expansions (`tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
|
||||
The temporal expansion factor for each of the up blocks.
|
||||
spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
|
||||
spatial_expansions (`tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
|
||||
The spatial expansion factor for each of the up blocks.
|
||||
non_linearity (`str`, *optional*, defaults to `"swish"`):
|
||||
The non-linearity to use in the decoder.
|
||||
@@ -461,11 +461,11 @@ class MochiEncoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
|
||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
|
||||
add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 768),
|
||||
layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
temporal_expansions: tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: tuple[int, ...] = (2, 2, 2),
|
||||
add_attention_block: tuple[bool, ...] = (False, True, True, True, True),
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -500,7 +500,7 @@ class MochiEncoder3D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
||||
self, hidden_states: torch.Tensor, conv_cache: Optional[dict[str, torch.Tensor]] = None
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiEncoder3D` class."""
|
||||
|
||||
@@ -558,13 +558,13 @@ class MochiDecoder3D(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
The number of output channels.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
|
||||
layers_per_block (`tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
|
||||
The number of resnet blocks for each block.
|
||||
temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
|
||||
temporal_expansions (`tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
|
||||
The temporal expansion factor for each of the up blocks.
|
||||
spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
|
||||
spatial_expansions (`tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
|
||||
The spatial expansion factor for each of the up blocks.
|
||||
non_linearity (`str`, *optional*, defaults to `"swish"`):
|
||||
The non-linearity to use in the decoder.
|
||||
@@ -574,10 +574,10 @@ class MochiDecoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int, # 12
|
||||
out_channels: int, # 3
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
|
||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 768),
|
||||
layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
temporal_expansions: tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: tuple[int, ...] = (2, 2, 2),
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -613,7 +613,7 @@ class MochiDecoder3D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
||||
self, hidden_states: torch.Tensor, conv_cache: Optional[dict[str, torch.Tensor]] = None
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiDecoder3D` class."""
|
||||
|
||||
@@ -668,8 +668,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
@@ -688,15 +688,15 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 15,
|
||||
out_channels: int = 3,
|
||||
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
|
||||
encoder_block_out_channels: tuple[int] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: tuple[int] = (128, 256, 512, 768),
|
||||
latent_channels: int = 12,
|
||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
act_fn: str = "silu",
|
||||
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
|
||||
add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
|
||||
latents_mean: Tuple[float, ...] = (
|
||||
temporal_expansions: tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: tuple[int, ...] = (2, 2, 2),
|
||||
add_attention_block: tuple[bool, ...] = (False, True, True, True, True),
|
||||
latents_mean: tuple[float, ...] = (
|
||||
-0.06730895953510081,
|
||||
-0.038011381506090416,
|
||||
-0.07477820912866141,
|
||||
@@ -710,7 +710,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
-0.011931556316503654,
|
||||
-0.0321993391887285,
|
||||
),
|
||||
latents_std: Tuple[float, ...] = (
|
||||
latents_std: tuple[float, ...] = (
|
||||
0.9263795028493863,
|
||||
0.9248894543193766,
|
||||
0.9393059390890617,
|
||||
@@ -860,7 +860,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -885,7 +885,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
@@ -915,7 +915,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1013,7 +1013,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1097,7 +1097,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor | torch.Tensor:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# - GitHub: https://github.com/Wan-Video/Wan2.1
|
||||
# - arXiv: https://arxiv.org/abs/2503.20314
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -58,9 +58,9 @@ class QwenImageCausalConv3d(nn.Conv3d):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
kernel_size: int | tuple[int, int, int],
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
padding: int | tuple[int, int, int] = 0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
@@ -679,13 +679,13 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
dim_mult: tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
attn_scales: list[float] = [],
|
||||
temperal_downsample: list[bool] = [False, True, True],
|
||||
dropout: float = 0.0,
|
||||
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
|
||||
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
|
||||
latents_mean: list[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
|
||||
latents_std: list[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
|
||||
) -> None:
|
||||
# fmt: on
|
||||
super().__init__()
|
||||
@@ -806,7 +806,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -856,7 +856,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
return DecoderOutput(sample=out)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -962,7 +962,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1031,7 +1031,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -145,10 +145,10 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to 1): Number of layers per block.
|
||||
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
down_block_types: tuple[str] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: tuple[int] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
latent_channels: int = 4,
|
||||
sample_size: int = 32,
|
||||
@@ -204,7 +204,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -213,7 +213,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -228,7 +228,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -278,7 +278,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -308,7 +308,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
z: torch.Tensor,
|
||||
num_frames: int,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -339,7 +339,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
num_frames: int = 1,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -149,9 +149,9 @@ class WanCausalConv3d(nn.Conv3d):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
kernel_size: int | tuple[int, int, int],
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
padding: int | tuple[int, int, int] = 0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
@@ -971,12 +971,12 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
base_dim: int = 96,
|
||||
decoder_base_dim: Optional[int] = None,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
dim_mult: tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
attn_scales: list[float] = [],
|
||||
temperal_downsample: list[bool] = [False, True, True],
|
||||
dropout: float = 0.0,
|
||||
latents_mean: List[float] = [
|
||||
latents_mean: list[float] = [
|
||||
-0.7571,
|
||||
-0.7089,
|
||||
-0.9113,
|
||||
@@ -994,7 +994,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
0.2503,
|
||||
-0.2921,
|
||||
],
|
||||
latents_std: List[float] = [
|
||||
latents_std: list[float] = [
|
||||
2.8184,
|
||||
1.4541,
|
||||
2.3275,
|
||||
@@ -1153,7 +1153,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -1209,7 +1209,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
return DecoderOutput(sample=out)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1315,7 +1315,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1399,7 +1399,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -303,9 +303,9 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
encoder_hidden_size (`int`, *optional*, defaults to 128):
|
||||
Intermediate representation dimension for the encoder.
|
||||
downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
|
||||
downsampling_ratios (`list[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
|
||||
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
|
||||
channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
|
||||
channel_multiples (`list[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
|
||||
Multiples used to determine the hidden sizes of the hidden layers.
|
||||
decoder_channels (`int`, *optional*, defaults to 128):
|
||||
Intermediate representation dimension for the decoder.
|
||||
@@ -360,7 +360,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderOobleckOutput | tuple[OobleckDiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -386,7 +386,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
return AutoencoderOobleckOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> OobleckDecoderOutput | torch.Tensor:
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
@@ -397,7 +397,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
|
||||
) -> OobleckDecoderOutput | torch.FloatTensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -429,7 +429,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[OobleckDecoderOutput, torch.Tensor]:
|
||||
) -> OobleckDecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -50,11 +50,11 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
||||
encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
||||
Tuple of integers representing the number of output channels for each encoder block. The length of the
|
||||
encoder_block_out_channels (`tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
||||
tuple of integers representing the number of output channels for each encoder block. The length of the
|
||||
tuple should be equal to the number of encoder blocks.
|
||||
decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
||||
Tuple of integers representing the number of output channels for each decoder block. The length of the
|
||||
decoder_block_out_channels (`tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
||||
tuple of integers representing the number of output channels for each decoder block. The length of the
|
||||
tuple should be equal to the number of decoder blocks.
|
||||
act_fn (`str`, *optional*, defaults to `"relu"`):
|
||||
Activation function to be used throughout the model.
|
||||
@@ -64,12 +64,12 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
upsampling_scaling_factor (`int`, *optional*, defaults to 2):
|
||||
Scaling factor for upsampling in the decoder. It determines the size of the output image during the
|
||||
upsampling process.
|
||||
num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
|
||||
Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
|
||||
num_encoder_blocks (`tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
|
||||
tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
|
||||
length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
|
||||
number of encoder blocks.
|
||||
num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
|
||||
Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
|
||||
num_decoder_blocks (`tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
|
||||
tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
|
||||
length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
|
||||
number of decoder blocks.
|
||||
latent_magnitude (`float`, *optional*, defaults to 3.0):
|
||||
@@ -99,14 +99,14 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
|
||||
encoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
|
||||
decoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
|
||||
act_fn: str = "relu",
|
||||
upsample_fn: str = "nearest",
|
||||
latent_channels: int = 4,
|
||||
upsampling_scaling_factor: int = 2,
|
||||
num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
|
||||
num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
|
||||
num_encoder_blocks: tuple[int, ...] = (1, 3, 3, 3),
|
||||
num_decoder_blocks: tuple[int, ...] = (3, 3, 3, 1),
|
||||
latent_magnitude: int = 3,
|
||||
latent_shift: float = 0.5,
|
||||
force_upcast: bool = False,
|
||||
@@ -258,7 +258,7 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return out
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderTinyOutput | tuple[torch.Tensor]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [
|
||||
self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
|
||||
@@ -275,7 +275,7 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [
|
||||
self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
|
||||
@@ -293,7 +293,7 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -77,9 +77,9 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
latent_channels: int = 4,
|
||||
sample_size: int = 32,
|
||||
encoder_act_fn: str = "silu",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
encoder_double_z: bool = True,
|
||||
encoder_down_block_types: Tuple[str, ...] = (
|
||||
encoder_down_block_types: tuple[str, ...] = (
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
@@ -90,8 +90,8 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
encoder_norm_num_groups: int = 32,
|
||||
encoder_out_channels: int = 4,
|
||||
decoder_add_attention: bool = False,
|
||||
decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
|
||||
decoder_down_block_types: Tuple[str, ...] = (
|
||||
decoder_block_out_channels: tuple[int, ...] = (320, 640, 1024, 1024),
|
||||
decoder_down_block_types: tuple[str, ...] = (
|
||||
"ResnetDownsampleBlock2D",
|
||||
"ResnetDownsampleBlock2D",
|
||||
"ResnetDownsampleBlock2D",
|
||||
@@ -106,7 +106,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
decoder_out_channels: int = 6,
|
||||
decoder_resnet_time_scale_shift: str = "scale_shift",
|
||||
decoder_time_embedding_type: str = "learned",
|
||||
decoder_up_block_types: Tuple[str, ...] = (
|
||||
decoder_up_block_types: tuple[str, ...] = (
|
||||
"ResnetUpsampleBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
@@ -169,7 +169,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -178,7 +178,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -193,7 +193,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -246,7 +246,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> ConsistencyDecoderVAEOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -285,7 +285,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
num_inference_steps: int = 2,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
"""
|
||||
Decodes the input latent vector `z` using the consistency decoder VAE model.
|
||||
|
||||
@@ -296,7 +296,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
num_inference_steps (int): The number of inference steps. Default is 2.
|
||||
|
||||
Returns:
|
||||
Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
|
||||
Union[DecoderOutput, tuple[torch.Tensor]]: The decoded output.
|
||||
|
||||
"""
|
||||
z = (z * self.config.scaling_factor - self.means) / self.stds
|
||||
@@ -339,7 +339,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput | tuple:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
@@ -400,7 +400,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -66,10 +66,10 @@ class Encoder(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
|
||||
options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
@@ -85,8 +85,8 @@ class Encoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -187,9 +187,9 @@ class Decoder(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
@@ -205,8 +205,8 @@ class Decoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -286,11 +286,9 @@ class Decoder(nn.Module):
|
||||
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# middle
|
||||
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
@@ -298,7 +296,6 @@ class Decoder(nn.Module):
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
@@ -405,9 +402,9 @@ class MaskConditionDecoder(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
@@ -423,8 +420,8 @@ class MaskConditionDecoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -636,7 +633,7 @@ class VectorQuantizer(nn.Module):
|
||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
|
||||
def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tuple]:
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.vq_embed_dim)
|
||||
@@ -670,7 +667,7 @@ class VectorQuantizer(nn.Module):
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||
|
||||
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
|
||||
def get_codebook_entry(self, indices: torch.LongTensor, shape: tuple[int, ...]) -> torch.Tensor:
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||
@@ -731,7 +728,7 @@ class DiagonalGaussianDistribution(object):
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
|
||||
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
||||
def nll(self, sample: torch.Tensor, dims: tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
@@ -764,10 +761,10 @@ class EncoderTiny(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`):
|
||||
The number of output channels.
|
||||
num_blocks (`Tuple[int, ...]`):
|
||||
num_blocks (`tuple[int, ...]`):
|
||||
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
|
||||
use.
|
||||
block_out_channels (`Tuple[int, ...]`):
|
||||
block_out_channels (`tuple[int, ...]`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
@@ -777,8 +774,8 @@ class EncoderTiny(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: Tuple[int, ...],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_blocks: tuple[int, ...],
|
||||
block_out_channels: tuple[int, ...],
|
||||
act_fn: str,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -830,10 +827,10 @@ class DecoderTiny(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`):
|
||||
The number of output channels.
|
||||
num_blocks (`Tuple[int, ...]`):
|
||||
num_blocks (`tuple[int, ...]`):
|
||||
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
|
||||
use.
|
||||
block_out_channels (`Tuple[int, ...]`):
|
||||
block_out_channels (`tuple[int, ...]`):
|
||||
The number of output channels for each block.
|
||||
upsampling_scaling_factor (`int`):
|
||||
The scaling factor to use for upsampling.
|
||||
@@ -845,8 +842,8 @@ class DecoderTiny(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: Tuple[int, ...],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_blocks: tuple[int, ...],
|
||||
block_out_channels: tuple[int, ...],
|
||||
upsampling_scaling_factor: int,
|
||||
act_fn: str,
|
||||
upsample_fn: str,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -48,12 +48,12 @@ class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
tuple of downsample block types.
|
||||
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
tuple of upsample block types.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
||||
@@ -80,9 +80,9 @@ class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 3,
|
||||
@@ -143,7 +143,7 @@ class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, commit_loss, _ = self.quantize(h)
|
||||
@@ -161,9 +161,7 @@ class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
return DecoderOutput(sample=dec, commit_loss=commit_loss)
|
||||
|
||||
def forward(
|
||||
self, sample: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
|
||||
def forward(self, sample: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor, ...]:
|
||||
r"""
|
||||
The [`VQModel`] forward method.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
from ..utils import deprecate
|
||||
from .controlnets.controlnet import ( # noqa
|
||||
@@ -36,15 +36,15 @@ class ControlNetModel(ControlNetModel):
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -52,11 +52,11 @@ class ControlNetModel(ControlNetModel):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
@@ -66,7 +66,7 @@ class ControlNetModel(ControlNetModel):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
|
||||
|
||||
@@ -41,7 +39,7 @@ class FluxControlNetModel(FluxControlNetModel):
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
axes_dims_rope: list[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sparsectrl import ( # noqa
|
||||
@@ -50,14 +50,14 @@ class SparseControlNetModel(SparseControlNetModel):
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -65,15 +65,15 @@ class SparseControlNetModel(SparseControlNetModel):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
transformer_layers_per_mid_block: Optional[int | tuple[int]] = None,
|
||||
temporal_transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -57,7 +57,7 @@ class ControlNetOutput(BaseOutput):
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
down_block_res_samples: tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
block_out_channels: tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -119,7 +119,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
only_cross_attention (`Union[bool, tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
@@ -137,7 +137,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
@@ -147,7 +147,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
attention_head_dim (`Union[int, tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
@@ -184,15 +184,15 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -200,11 +200,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
@@ -214,7 +214,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
@@ -444,7 +444,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 3,
|
||||
):
|
||||
@@ -517,7 +517,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -526,7 +526,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -541,7 +541,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -592,7 +592,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
def set_attention_slice(self, slice_size: str | int | list[int]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
@@ -646,7 +646,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
@@ -660,18 +660,18 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
) -> ControlNetOutput | tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
"""
|
||||
The [`ControlNetModel`] forward method.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
@@ -49,7 +49,7 @@ class FlaxControlNetOutput(BaseOutput):
|
||||
|
||||
class FlaxControlNetConditioningEmbedding(nn.Module):
|
||||
conditioning_embedding_channels: int
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
|
||||
block_out_channels: tuple[int, ...] = (16, 32, 96, 256)
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self) -> None:
|
||||
@@ -132,15 +132,15 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
The size of the input sample.
|
||||
in_channels (`int`, *optional*, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
|
||||
attention_head_dim (`int` or `tuple[int]`, *optional*, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
||||
num_attention_heads (`int` or `tuple[int]`, *optional*):
|
||||
The number of attention heads.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 768):
|
||||
The dimension of the cross attention features.
|
||||
@@ -157,17 +157,17 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
|
||||
sample_size: int = 32
|
||||
in_channels: int = 4
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
)
|
||||
only_cross_attention: Union[bool, Tuple[bool, ...]] = False
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
|
||||
only_cross_attention: bool | tuple[bool, ...] = False
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280)
|
||||
layers_per_block: int = 2
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
|
||||
attention_head_dim: int | tuple[int, ...] = 8
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None
|
||||
cross_attention_dim: int = 1280
|
||||
dropout: float = 0.0
|
||||
use_linear_projection: bool = False
|
||||
@@ -175,7 +175,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
flip_sin_to_cos: bool = True
|
||||
freq_shift: int = 0
|
||||
controlnet_conditioning_channel_order: str = "rgb"
|
||||
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
|
||||
conditioning_embedding_out_channels: tuple[int, ...] = (16, 32, 96, 256)
|
||||
|
||||
def init_weights(self, rng: jax.Array) -> FrozenDict:
|
||||
# init input tensors
|
||||
@@ -327,13 +327,13 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
def __call__(
|
||||
self,
|
||||
sample: jnp.ndarray,
|
||||
timesteps: Union[jnp.ndarray, float, int],
|
||||
timesteps: jnp.ndarray | float | int,
|
||||
encoder_hidden_states: jnp.ndarray,
|
||||
controlnet_cond: jnp.ndarray,
|
||||
conditioning_scale: float = 1.0,
|
||||
return_dict: bool = True,
|
||||
train: bool = False,
|
||||
) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
|
||||
) -> FlaxControlNetOutput | tuple[tuple[jnp.ndarray, ...], jnp.ndarray]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -34,8 +34,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class FluxControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
@@ -53,7 +53,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
axes_dims_rope: list[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
@@ -129,7 +129,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -222,9 +222,9 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> torch.FloatTensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
@@ -404,7 +404,7 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
compatible with `FluxControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[FluxControlNetModel]`):
|
||||
controlnets (`list[FluxControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`FluxControlNetModel` as a list.
|
||||
"""
|
||||
@@ -416,18 +416,18 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
controlnet_mode: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
controlnet_mode: list[torch.tensor],
|
||||
conditioning_scale: list[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FluxControlNetOutput, Tuple]:
|
||||
) -> FluxControlNetOutput | tuple:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1:
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -27,7 +27,7 @@ from ..embeddings import (
|
||||
)
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
|
||||
from .controlnet import Tuple, zero_module
|
||||
from .controlnet import zero_module
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class HunyuanControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
@@ -116,7 +116,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -125,7 +125,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
@@ -139,7 +139,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -317,7 +317,7 @@ class HunyuanDiT2DMultiControlNetModel(ModelMixin):
|
||||
designed to be compatible with `HunyuanDiT2DControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[HunyuanDiT2DControlNetModel]`):
|
||||
controlnets (`list[HunyuanDiT2DControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`HunyuanDiT2DControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -39,7 +39,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class QwenImageControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
@@ -55,7 +55,7 @@ class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
|
||||
extra_condition_channels: int = 0, # for controlnet-inpainting
|
||||
):
|
||||
super().__init__()
|
||||
@@ -103,7 +103,7 @@ class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -188,11 +188,11 @@ class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_hidden_states_mask: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[List[int]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
img_shapes: Optional[list[tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[list[int]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> torch.FloatTensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
@@ -303,7 +303,7 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
to be compatible with `QwenImageControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[QwenImageControlNetModel]`):
|
||||
controlnets (`list[QwenImageControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`QwenImageControlNetModel` as a list.
|
||||
"""
|
||||
@@ -315,16 +315,16 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
conditioning_scale: list[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_hidden_states_mask: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[List[int]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
img_shapes: Optional[list[tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[list[int]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[QwenImageControlNetOutput, Tuple]:
|
||||
) -> QwenImageControlNetOutput | tuple:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class SanaControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
@@ -119,7 +119,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -128,7 +128,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -143,7 +143,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -186,9 +186,9 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class SD3ControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
@@ -69,7 +69,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
The maximum latent height/width of positional embeddings.
|
||||
extra_conditioning_channels (`int`, defaults to `0`):
|
||||
The number of extra channels to use for conditioning for patch embedding.
|
||||
dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
|
||||
dual_attention_layers (`tuple[int, ...]`, defaults to `()`):
|
||||
The number of dual-stream transformer blocks to use.
|
||||
qk_norm (`str`, *optional*, defaults to `None`):
|
||||
The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
|
||||
@@ -99,7 +99,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
dual_attention_layers: Tuple[int, ...] = (),
|
||||
dual_attention_layers: tuple[int, ...] = (),
|
||||
qk_norm: Optional[str] = None,
|
||||
pos_embed_type: Optional[str] = "sincos",
|
||||
use_pos_embed: bool = True,
|
||||
@@ -206,7 +206,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -215,7 +215,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -230,7 +230,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -337,9 +337,9 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> torch.Tensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`SD3Transformer2DModel`] forward method.
|
||||
|
||||
@@ -460,7 +460,7 @@ class SD3MultiControlNetModel(ModelMixin):
|
||||
compatible with `SD3ControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[SD3ControlNetModel]`):
|
||||
controlnets (`list[SD3ControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`SD3ControlNetModel` as a list.
|
||||
"""
|
||||
@@ -472,14 +472,14 @@ class SD3MultiControlNetModel(ModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
conditioning_scale: list[float],
|
||||
pooled_projections: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SD3ControlNetOutput, Tuple]:
|
||||
) -> SD3ControlNetOutput | tuple:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -55,7 +55,7 @@ class SparseControlNetOutput(BaseOutput):
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
down_block_res_samples: tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ class SparseControlNetConditioningEmbedding(nn.Module):
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
block_out_channels: tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -110,7 +110,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
only_cross_attention (`Union[bool, tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
@@ -128,28 +128,28 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_mid_block (`int` or `tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer layers to use in each layer in the middle block.
|
||||
attention_head_dim (`int` or `Tuple[int]`, defaults to 8):
|
||||
attention_head_dim (`int` or `tuple[int]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
||||
num_attention_heads (`int` or `tuple[int]`, *optional*):
|
||||
The number of heads to use for multi-head attention.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `rgb`):
|
||||
motion_max_seq_length (`int`, defaults to `32`):
|
||||
The maximum sequence length to use in the motion module.
|
||||
motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`):
|
||||
motion_num_attention_heads (`int` or `tuple[int]`, defaults to `8`):
|
||||
The number of heads to use in each attention layer of the motion module.
|
||||
concat_conditioning_mask (`bool`, defaults to `True`):
|
||||
use_simplified_condition_embedding (`bool`, defaults to `True`):
|
||||
@@ -164,14 +164,14 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -179,15 +179,15 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
transformer_layers_per_mid_block: Optional[int | tuple[int]] = None,
|
||||
temporal_transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
@@ -389,7 +389,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 3,
|
||||
) -> "SparseControlNetModel":
|
||||
@@ -450,7 +450,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -459,7 +459,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -474,7 +474,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -525,7 +525,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
def set_attention_slice(self, slice_size: str | int | list[int]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
@@ -579,7 +579,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
@@ -593,17 +593,17 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
conditioning_mask: Optional[torch.Tensor] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
) -> SparseControlNetOutput | tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
"""
|
||||
The [`SparseControlNetModel`] forward method.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -94,7 +94,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
only_cross_attention (`Union[bool, tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
@@ -112,7 +112,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
@@ -122,7 +122,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
attention_head_dim (`Union[int, tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
@@ -156,14 +156,14 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -171,11 +171,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
@@ -185,7 +185,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (48, 96, 192, 384),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (48, 96, 192, 384),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
num_control_type: int = 6,
|
||||
@@ -390,7 +390,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
):
|
||||
r"""
|
||||
@@ -457,7 +457,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -466,7 +466,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -481,7 +481,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -532,7 +532,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
def set_attention_slice(self, slice_size: str | int | list[int]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
@@ -586,7 +586,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
@@ -600,21 +600,21 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.Tensor],
|
||||
controlnet_cond: list[torch.Tensor],
|
||||
control_type: torch.Tensor,
|
||||
control_type_idx: List[int],
|
||||
conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
control_type_idx: list[int],
|
||||
conditioning_scale: float | list[float] = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
from_multi: bool = False,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
) -> ControlNetOutput | tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
"""
|
||||
The [`ControlNetUnionModel`] forward method.
|
||||
|
||||
@@ -625,12 +625,12 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`List[torch.Tensor]`):
|
||||
controlnet_cond (`list[torch.Tensor]`):
|
||||
The conditional input tensors.
|
||||
control_type (`torch.Tensor`):
|
||||
A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
|
||||
type is used.
|
||||
control_type_idx (`List[int]`):
|
||||
control_type_idx (`list[int]`):
|
||||
The indices of `control_type`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from math import gcd
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@@ -109,7 +109,7 @@ def get_down_block_adapter(
|
||||
temb_channels: int,
|
||||
max_norm_num_groups: Optional[int] = 32,
|
||||
has_crossattn=True,
|
||||
transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
|
||||
transformer_layers_per_block: Optional[int | tuple[int]] = 1,
|
||||
num_attention_heads: Optional[int] = 1,
|
||||
cross_attention_dim: Optional[int] = 1024,
|
||||
add_downsample: bool = True,
|
||||
@@ -230,7 +230,7 @@ def get_mid_block_adapter(
|
||||
def get_up_block_adapter(
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
ctrl_skip_channels: List[int],
|
||||
ctrl_skip_channels: list[int],
|
||||
):
|
||||
ctrl_to_base = []
|
||||
num_layers = 3 # only support sd + sdxl
|
||||
@@ -278,7 +278,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
The tuple of downsample blocks to use.
|
||||
sample_size (`int`, defaults to 96):
|
||||
Height and width of input/output sample.
|
||||
transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1):
|
||||
transformer_layers_per_block (`Union[int, tuple[int]]`, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
upcast_attention (`bool`, defaults to `True`):
|
||||
@@ -293,21 +293,21 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: tuple[int] = (16, 32, 96, 256),
|
||||
time_embedding_mix: float = 1.0,
|
||||
learn_time_embedding: bool = False,
|
||||
num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
num_attention_heads: int | tuple[int] = 4,
|
||||
block_out_channels: tuple[int] = (4, 8, 16, 16),
|
||||
base_block_out_channels: tuple[int] = (320, 640, 1280, 1280),
|
||||
cross_attention_dim: int = 1024,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
sample_size: Optional[int] = 96,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
transformer_layers_per_block: int | tuple[int] = 1,
|
||||
upcast_attention: bool = True,
|
||||
max_norm_num_groups: int = 32,
|
||||
use_linear_projection: bool = True,
|
||||
@@ -430,13 +430,13 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
size_ratio: Optional[float] = None,
|
||||
block_out_channels: Optional[List[int]] = None,
|
||||
num_attention_heads: Optional[List[int]] = None,
|
||||
block_out_channels: Optional[list[int]] = None,
|
||||
num_attention_heads: Optional[list[int]] = None,
|
||||
learn_time_embedding: bool = False,
|
||||
time_embedding_mix: int = 1.0,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: tuple[int] = (16, 32, 96, 256),
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
|
||||
@@ -447,9 +447,9 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
size_ratio (float, *optional*, defaults to `None`):
|
||||
When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this
|
||||
or `block_out_channels` must be given.
|
||||
block_out_channels (`List[int]`, *optional*, defaults to `None`):
|
||||
block_out_channels (`list[int]`, *optional*, defaults to `None`):
|
||||
Down blocks output channels in control model. Either this or `size_ratio` must be given.
|
||||
num_attention_heads (`List[int]`, *optional*, defaults to `None`):
|
||||
num_attention_heads (`list[int]`, *optional*, defaults to `None`):
|
||||
The dimension of the attention heads. The naming seems a bit confusing and it is, see
|
||||
https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
|
||||
learn_time_embedding (`bool`, defaults to `False`):
|
||||
@@ -461,7 +461,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
Number of channels of conditioning input (e.g. an image)
|
||||
conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
|
||||
"""
|
||||
|
||||
@@ -529,18 +529,18 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
# unet configs
|
||||
sample_size: Optional[int] = 96,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
up_block_types: tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels: tuple[int] = (320, 640, 1280, 1280),
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
num_attention_heads: Union[int, Tuple[int]] = 8,
|
||||
cross_attention_dim: int | tuple[int] = 1024,
|
||||
transformer_layers_per_block: int | tuple[int] = 1,
|
||||
num_attention_heads: int | tuple[int] = 8,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
upcast_attention: bool = True,
|
||||
@@ -550,11 +550,11 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
# additional controlnet configs
|
||||
time_embedding_mix: float = 1.0,
|
||||
ctrl_conditioning_channels: int = 3,
|
||||
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_embedding_out_channels: tuple[int] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_channel_order: str = "rgb",
|
||||
ctrl_learn_time_embedding: bool = False,
|
||||
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
ctrl_block_out_channels: tuple[int] = (4, 8, 16, 16),
|
||||
ctrl_num_attention_heads: int | tuple[int] = 4,
|
||||
ctrl_max_norm_num_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -721,7 +721,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: Optional[ControlNetXSAdapter] = None,
|
||||
size_ratio: Optional[float] = None,
|
||||
ctrl_block_out_channels: Optional[List[float]] = None,
|
||||
ctrl_block_out_channels: Optional[list[float]] = None,
|
||||
time_embedding_mix: Optional[float] = None,
|
||||
ctrl_optional_kwargs: Optional[Dict] = None,
|
||||
):
|
||||
@@ -737,7 +737,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
adapter will be created.
|
||||
size_ratio (float, *optional*, defaults to `None`):
|
||||
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
|
||||
ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
|
||||
ctrl_block_out_channels (`list[int]`, *optional*, defaults to `None`):
|
||||
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
|
||||
where this parameter is called `block_out_channels`.
|
||||
time_embedding_mix (`float`, *optional*, defaults to None):
|
||||
@@ -865,7 +865,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -874,7 +874,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -889,7 +889,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -1008,18 +1008,18 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: Optional[torch.Tensor] = None,
|
||||
conditioning_scale: Optional[float] = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
apply_control: bool = True,
|
||||
) -> Union[ControlNetXSOutput, Tuple]:
|
||||
) -> ControlNetXSOutput | tuple:
|
||||
"""
|
||||
The [`ControlNetXSModel`] forward method.
|
||||
|
||||
@@ -1221,7 +1221,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
norm_num_groups: int = 32,
|
||||
ctrl_max_norm_num_groups: int = 32,
|
||||
has_crossattn=True,
|
||||
transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
|
||||
transformer_layers_per_block: Optional[int | tuple[int]] = 1,
|
||||
base_num_attention_heads: Optional[int] = 1,
|
||||
ctrl_num_attention_heads: Optional[int] = 1,
|
||||
cross_attention_dim: Optional[int] = 1024,
|
||||
@@ -1420,10 +1420,10 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
hidden_states_ctrl: Optional[Tensor] = None,
|
||||
conditioning_scale: Optional[float] = 1.0,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[Tensor] = None,
|
||||
apply_control: bool = True,
|
||||
) -> Tuple[Tensor, Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
|
||||
) -> tuple[Tensor, Tensor, tuple[Tensor, ...], tuple[Tensor, ...]]:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
@@ -1625,11 +1625,11 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
||||
encoder_hidden_states: Tensor,
|
||||
hidden_states_ctrl: Optional[Tensor] = None,
|
||||
conditioning_scale: Optional[float] = 1.0,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
encoder_attention_mask: Optional[Tensor] = None,
|
||||
apply_control: bool = True,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
@@ -1661,7 +1661,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
ctrl_skip_channels: List[int],
|
||||
ctrl_skip_channels: list[int],
|
||||
temb_channels: int,
|
||||
norm_num_groups: int = 32,
|
||||
resolution_idx: Optional[int] = None,
|
||||
@@ -1806,12 +1806,12 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Tensor,
|
||||
res_hidden_states_tuple_base: Tuple[Tensor, ...],
|
||||
res_hidden_states_tuple_ctrl: Tuple[Tensor, ...],
|
||||
res_hidden_states_tuple_base: tuple[Tensor, ...],
|
||||
res_hidden_states_tuple_ctrl: tuple[Tensor, ...],
|
||||
temb: Tensor,
|
||||
encoder_hidden_states: Optional[Tensor] = None,
|
||||
conditioning_scale: Optional[float] = 1.0,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
encoder_attention_mask: Optional[Tensor] = None,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -20,30 +20,30 @@ class MultiControlNetModel(ModelMixin):
|
||||
compatible with `ControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[ControlNetModel]`):
|
||||
controlnets (`list[ControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`ControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
|
||||
def __init__(self, controlnets: list[ControlNetModel] | tuple[ControlNetModel]):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
conditioning_scale: list[float],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
) -> ControlNetOutput | tuple:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
down_samples, mid_sample = controlnet(
|
||||
sample=sample,
|
||||
@@ -74,7 +74,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_directory: str | os.PathLike,
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
@@ -111,7 +111,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[str | os.PathLike], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
|
||||
|
||||
@@ -134,7 +134,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`str` or `dict[str, int | str | torch.device]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -21,32 +21,32 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
be compatible with `ControlNetUnionModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[ControlNetUnionModel]`):
|
||||
controlnets (`list[ControlNetUnionModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`ControlNetUnionModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets: Union[List[ControlNetUnionModel], Tuple[ControlNetUnionModel]]):
|
||||
def __init__(self, controlnets: list[ControlNetUnionModel] | tuple[ControlNetUnionModel]):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
control_type: List[torch.Tensor],
|
||||
control_type_idx: List[List[int]],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
control_type: list[torch.Tensor],
|
||||
control_type_idx: list[list[int]],
|
||||
conditioning_scale: list[float],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
) -> ControlNetOutput | tuple:
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
|
||||
@@ -86,7 +86,7 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
# Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained with ControlNet->ControlNetUnion
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_directory: str | os.PathLike,
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
@@ -124,7 +124,7 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained with ControlNet->ControlNetUnion
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[str | os.PathLike], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained MultiControlNetUnion model from multiple pre-trained controlnet models.
|
||||
|
||||
@@ -147,7 +147,7 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`str` or `dict[str, int | str | torch.device]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -168,7 +168,7 @@ class FirDownsample2D(nn.Module):
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
fir_kernel: tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -80,7 +80,7 @@ def get_timestep_embedding(
|
||||
|
||||
def get_3d_sincos_pos_embed(
|
||||
embed_dim: int,
|
||||
spatial_size: Union[int, Tuple[int, int]],
|
||||
spatial_size: int | tuple[int, int],
|
||||
temporal_size: int,
|
||||
spatial_interpolation_scale: float = 1.0,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
@@ -93,7 +93,7 @@ def get_3d_sincos_pos_embed(
|
||||
Args:
|
||||
embed_dim (`int`):
|
||||
The embedding dimension of inputs. It must be divisible by 16.
|
||||
spatial_size (`int` or `Tuple[int, int]`):
|
||||
spatial_size (`int` or `tuple[int, int]`):
|
||||
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
|
||||
spatial dimensions (height and width).
|
||||
temporal_size (`int`):
|
||||
@@ -154,7 +154,7 @@ def get_3d_sincos_pos_embed(
|
||||
|
||||
def _get_3d_sincos_pos_embed_np(
|
||||
embed_dim: int,
|
||||
spatial_size: Union[int, Tuple[int, int]],
|
||||
spatial_size: int | tuple[int, int],
|
||||
temporal_size: int,
|
||||
spatial_interpolation_scale: float = 1.0,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
@@ -165,7 +165,7 @@ def _get_3d_sincos_pos_embed_np(
|
||||
Args:
|
||||
embed_dim (`int`):
|
||||
The embedding dimension of inputs. It must be divisible by 16.
|
||||
spatial_size (`int` or `Tuple[int, int]`):
|
||||
spatial_size (`int` or `tuple[int, int]`):
|
||||
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
|
||||
spatial dimensions (height and width).
|
||||
temporal_size (`int`):
|
||||
@@ -609,10 +609,10 @@ class LuminaPatchEmbed(nn.Module):
|
||||
Patchifies and embeds the input tensor(s).
|
||||
|
||||
Args:
|
||||
x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
|
||||
x (list[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
|
||||
tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]], torch.Tensor]: A tuple containing the patchified
|
||||
and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
|
||||
frequency tensor(s).
|
||||
"""
|
||||
@@ -836,18 +836,18 @@ def get_3d_rotary_pos_embed(
|
||||
theta: int = 10000,
|
||||
use_real: bool = True,
|
||||
grid_type: str = "linspace",
|
||||
max_size: Optional[Tuple[int, int]] = None,
|
||||
max_size: Optional[tuple[int, int]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
RoPE for video tokens with 3D structure.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size, corresponding to hidden_size_head.
|
||||
crops_coords (`Tuple[int]`):
|
||||
crops_coords (`tuple[int]`):
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
grid_size (`tuple[int]`):
|
||||
The grid size of the spatial positional embedding (height, width).
|
||||
temporal_size (`int`):
|
||||
The size of the temporal dimension.
|
||||
@@ -934,10 +934,10 @@ def get_3d_rotary_pos_embed_allegro(
|
||||
crops_coords,
|
||||
grid_size,
|
||||
temporal_size,
|
||||
interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
|
||||
interpolation_scale: tuple[float, float, float] = (1.0, 1.0, 1.0),
|
||||
theta: int = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# TODO(aryan): docs
|
||||
start, stop = crops_coords
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
@@ -981,9 +981,9 @@ def get_2d_rotary_pos_embed(
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size
|
||||
crops_coords (`Tuple[int]`)
|
||||
crops_coords (`tuple[int]`):
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
grid_size (`tuple[int]`):
|
||||
The grid size of the positional embedding.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
@@ -1029,9 +1029,9 @@ def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=Tru
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size
|
||||
crops_coords (`Tuple[int]`)
|
||||
crops_coords (`tuple[int]`):
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
grid_size (`tuple[int]`):
|
||||
The grid size of the positional embedding.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
@@ -1119,7 +1119,7 @@ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, n
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[np.ndarray, int],
|
||||
pos: np.ndarray | int,
|
||||
theta: float = 10000.0,
|
||||
use_real=False,
|
||||
linear_factor=1.0,
|
||||
@@ -1186,11 +1186,11 @@ def get_1d_rotary_pos_embed(
|
||||
|
||||
def apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
freqs_cis: torch.Tensor | tuple[torch.Tensor],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
sequence_dim: int = 2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
@@ -1200,10 +1200,10 @@ def apply_rotary_emb(
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings to, with shape [B, H, S, D].
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
freqs_cis (`tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
tuple[torch.Tensor, torch.Tensor]: tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
@@ -2543,7 +2543,7 @@ class IPAdapterTimeImageProjection(nn.Module):
|
||||
self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
|
||||
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
@@ -2552,7 +2552,7 @@ class IPAdapterTimeImageProjection(nn.Module):
|
||||
timestep (`torch.Tensor`):
|
||||
Timestep in denoising process.
|
||||
Returns:
|
||||
`Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
|
||||
`tuple[torch.Tensor, torch.Tensor]`: The pair (latents, timestep_emb).
|
||||
"""
|
||||
timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
|
||||
timestep_emb = self.time_embedding(timestep_emb)
|
||||
@@ -2572,7 +2572,7 @@ class IPAdapterTimeImageProjection(nn.Module):
|
||||
|
||||
|
||||
class MultiIPAdapterImageProjection(nn.Module):
|
||||
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
|
||||
def __init__(self, IPAdapterImageProjectionLayers: list[nn.Module] | tuple[nn.Module]):
|
||||
super().__init__()
|
||||
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
|
||||
|
||||
@@ -2581,7 +2581,7 @@ class MultiIPAdapterImageProjection(nn.Module):
|
||||
"""Number of IP-Adapters loaded."""
|
||||
return len(self.image_projection_layers)
|
||||
|
||||
def forward(self, image_embeds: List[torch.Tensor]):
|
||||
def forward(self, image_embeds: list[torch.Tensor]):
|
||||
projected_image_embeds = []
|
||||
|
||||
# currently, we accept `image_embeds` as
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
# ----------------------------------------------------------------#
|
||||
###################################################################
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -199,7 +199,7 @@ class LoRALinearLayer(nn.Module):
|
||||
out_features: int,
|
||||
rank: int = 4,
|
||||
network_alpha: Optional[float] = None,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
device: Optional[torch.device | str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -260,9 +260,9 @@ class LoRAConv2dLayer(nn.Module):
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
rank: int = 4,
|
||||
kernel_size: Union[int, Tuple[int, int]] = (1, 1),
|
||||
stride: Union[int, Tuple[int, int]] = (1, 1),
|
||||
padding: Union[int, Tuple[int, int], str] = 0,
|
||||
kernel_size: int | tuple[int, int] = (1, 1),
|
||||
stride: int | tuple[int, int] = (1, 1),
|
||||
padding: int | tuple[int, int] | str = 0,
|
||||
network_alpha: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -22,7 +22,7 @@ from array import array
|
||||
from collections import OrderedDict, defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, Optional
|
||||
from zipfile import is_zipfile
|
||||
|
||||
import safetensors
|
||||
@@ -135,7 +135,7 @@ def _fetch_remapped_cls_from_config(config, old_class):
|
||||
return old_class
|
||||
|
||||
|
||||
def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
|
||||
def _determine_param_device(param_name: str, device_map: Optional[dict[str, int | str | torch.device]]):
|
||||
"""
|
||||
Find the device of param_name from the device_map.
|
||||
"""
|
||||
@@ -153,10 +153,10 @@ def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Unio
|
||||
|
||||
|
||||
def load_state_dict(
|
||||
checkpoint_file: Union[str, os.PathLike],
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
checkpoint_file: str | os.PathLike,
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = None,
|
||||
disable_mmap: bool = False,
|
||||
map_location: Union[str, torch.device] = "cpu",
|
||||
map_location: str | torch.device = "cpu",
|
||||
):
|
||||
"""
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
@@ -213,17 +213,17 @@ def load_state_dict(
|
||||
def load_model_dict_into_meta(
|
||||
model,
|
||||
state_dict: OrderedDict,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
dtype: Optional[str | torch.dtype] = None,
|
||||
model_name_or_path: Optional[str] = None,
|
||||
hf_quantizer: Optional[DiffusersQuantizer] = None,
|
||||
keep_in_fp32_modules: Optional[List] = None,
|
||||
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
|
||||
unexpected_keys: Optional[List[str]] = None,
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
keep_in_fp32_modules: Optional[list] = None,
|
||||
device_map: Optional[dict[str, int | str | torch.device]] = None,
|
||||
unexpected_keys: Optional[list[str]] = None,
|
||||
offload_folder: Optional[str | os.PathLike] = None,
|
||||
offload_index: Optional[Dict] = None,
|
||||
state_dict_index: Optional[Dict] = None,
|
||||
state_dict_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
) -> List[str]:
|
||||
state_dict_folder: Optional[str | os.PathLike] = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||
params on a `meta` device. It replaces the model params with the data from the `state_dict`
|
||||
@@ -466,7 +466,7 @@ def _find_mismatched_keys(
|
||||
|
||||
def _load_state_dict_into_model(
|
||||
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
state_dict = state_dict.copy()
|
||||
@@ -505,7 +505,7 @@ def _fetch_index_file(
|
||||
revision,
|
||||
user_agent,
|
||||
commit_hash,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = None,
|
||||
):
|
||||
if is_local:
|
||||
index_file = Path(
|
||||
@@ -555,7 +555,7 @@ def _fetch_index_file_legacy(
|
||||
revision,
|
||||
user_agent,
|
||||
commit_hash,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = None,
|
||||
):
|
||||
if is_local:
|
||||
index_file = Path(
|
||||
@@ -714,7 +714,7 @@ def _expand_device_map(device_map, param_names):
|
||||
|
||||
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
|
||||
def _caching_allocator_warmup(
|
||||
model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
|
||||
model, expanded_device_map: dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
|
||||
) -> None:
|
||||
"""
|
||||
This function warms up the caching allocator based on the size of the model tensors that will reside on each
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import os
|
||||
from pickle import UnpicklingError
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -68,7 +68,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
"""
|
||||
return cls(config, **kwargs)
|
||||
|
||||
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
|
||||
def _cast_floating_to(self, params: Dict | FrozenDict, dtype: jnp.dtype, mask: Any = None) -> Any:
|
||||
"""
|
||||
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
|
||||
"""
|
||||
@@ -92,7 +92,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
|
||||
return unflatten_dict(flat_params)
|
||||
|
||||
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
||||
def to_bf16(self, params: Dict | FrozenDict, mask: Any = None):
|
||||
r"""
|
||||
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
|
||||
the `params` in place.
|
||||
@@ -131,7 +131,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
```"""
|
||||
return self._cast_floating_to(params, jnp.bfloat16, mask)
|
||||
|
||||
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
||||
def to_fp32(self, params: Dict | FrozenDict, mask: Any = None):
|
||||
r"""
|
||||
Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
|
||||
model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
|
||||
@@ -158,7 +158,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
```"""
|
||||
return self._cast_floating_to(params, jnp.float32, mask)
|
||||
|
||||
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
||||
def to_fp16(self, params: Dict | FrozenDict, mask: Any = None):
|
||||
r"""
|
||||
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
|
||||
`params` in place.
|
||||
@@ -204,7 +204,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
pretrained_model_name_or_path: str | os.PathLike,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
*model_args,
|
||||
**kwargs,
|
||||
@@ -240,7 +240,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
@@ -493,8 +493,8 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
params: Union[Dict, FrozenDict],
|
||||
save_directory: str | os.PathLike,
|
||||
params: Dict | FrozenDict,
|
||||
is_main_process: bool = True,
|
||||
push_to_hub: bool = False,
|
||||
**kwargs,
|
||||
@@ -516,7 +516,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
|
||||
@@ -27,7 +27,7 @@ from collections import OrderedDict
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, ContextManager, Optional, Type
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -84,7 +84,7 @@ class ContextManagers:
|
||||
in the `fastcore` library.
|
||||
"""
|
||||
|
||||
def __init__(self, context_managers: List[ContextManager]):
|
||||
def __init__(self, context_managers: list[ContextManager]):
|
||||
self.context_managers = context_managers
|
||||
self.stack = ExitStack()
|
||||
|
||||
@@ -146,7 +146,7 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
||||
except StopIteration:
|
||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
def find_tensor_attributes(module: torch.nn.Module) -> list[tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
@@ -194,7 +194,7 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
||||
return last_dtype
|
||||
|
||||
# For nn.DataParallel compatibility in PyTorch > 1.5
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
@@ -439,8 +439,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
self,
|
||||
storage_dtype: torch.dtype = torch.float8_e4m3fn,
|
||||
compute_dtype: Optional[torch.dtype] = None,
|
||||
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
|
||||
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
|
||||
skip_modules_pattern: Optional[tuple[str, ...]] = None,
|
||||
skip_modules_classes: Optional[tuple[Type[torch.nn.Module], ...]] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
@@ -476,11 +476,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
The dtype to which the model should be cast for storage.
|
||||
compute_dtype (`torch.dtype`):
|
||||
The dtype to which the model weights should be cast during the forward pass.
|
||||
skip_modules_pattern (`Tuple[str, ...]`, *optional*):
|
||||
skip_modules_pattern (`tuple[str, ...]`, *optional*):
|
||||
A list of patterns to match the names of the modules to skip during the layerwise casting process. If
|
||||
set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
|
||||
layers.
|
||||
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
|
||||
skip_modules_classes (`tuple[Type[torch.nn.Module], ...]`, *optional*):
|
||||
A list of module classes to skip during the layerwise casting process.
|
||||
non_blocking (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the weight casting operations are non-blocking.
|
||||
@@ -639,12 +639,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_directory: str | os.PathLike,
|
||||
is_main_process: bool = True,
|
||||
save_function: Optional[Callable] = None,
|
||||
safe_serialization: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
max_shard_size: Union[int, str] = "10GB",
|
||||
max_shard_size: int | str = "10GB",
|
||||
push_to_hub: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -678,7 +678,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
@@ -806,7 +806,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[str | os.PathLike], **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
||||
|
||||
@@ -830,7 +830,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info (`bool`, *optional*, defaults to `False`):
|
||||
@@ -852,7 +852,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`Union[int, str, torch.device]` or `dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
|
||||
@@ -954,9 +954,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
|
||||
parallel_config: Optional[ParallelConfig | ContextParallelConfig] = kwargs.pop("parallel_config", None)
|
||||
|
||||
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
|
||||
if is_parallel_loading_enabled and not low_cpu_mem_usage:
|
||||
@@ -1481,8 +1481,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
def enable_parallelism(
|
||||
self,
|
||||
*,
|
||||
config: Union[ParallelConfig, ContextParallelConfig],
|
||||
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
|
||||
config: ParallelConfig | ContextParallelConfig,
|
||||
cp_plan: Optional[dict[str, ContextParallelModelPlan]] = None,
|
||||
):
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
@@ -1550,19 +1550,19 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
cls,
|
||||
model,
|
||||
state_dict: OrderedDict,
|
||||
resolved_model_file: List[str],
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
loaded_keys: List[str],
|
||||
resolved_model_file: list[str],
|
||||
pretrained_model_name_or_path: str | os.PathLike,
|
||||
loaded_keys: list[str],
|
||||
ignore_mismatched_sizes: bool = False,
|
||||
assign_to_params_buffers: bool = False,
|
||||
hf_quantizer: Optional[DiffusersQuantizer] = None,
|
||||
low_cpu_mem_usage: bool = True,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||
device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
|
||||
dtype: Optional[str | torch.dtype] = None,
|
||||
keep_in_fp32_modules: Optional[list[str]] = None,
|
||||
device_map: str | int | torch.device | dict[str, int | str | torch.device] = None,
|
||||
offload_state_dict: Optional[bool] = None,
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
offload_folder: Optional[str | os.PathLike] = None,
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = None,
|
||||
is_parallel_loading_enabled: Optional[bool] = False,
|
||||
):
|
||||
model_state_dict = model.state_dict()
|
||||
@@ -1722,7 +1722,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
|
||||
|
||||
Returns:
|
||||
`List[str]`: List of modules that should not be split
|
||||
`list[str]`: List of modules that should not be split
|
||||
"""
|
||||
_no_split_modules = set()
|
||||
modules_to_check = [self]
|
||||
@@ -1943,7 +1943,7 @@ class LegacyModelMixin(ModelMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[str | os.PathLike], **kwargs):
|
||||
# To prevent dependency import problem.
|
||||
from .model_loading_utils import _fetch_remapped_cls_from_config
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import numbers
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -117,7 +117,7 @@ class SD35AdaLayerNormZeroX(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
|
||||
9, dim=1
|
||||
@@ -162,7 +162,7 @@ class AdaLayerNormZero(nn.Module):
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if self.emb is not None:
|
||||
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
|
||||
emb = self.linear(self.silu(emb))
|
||||
@@ -196,7 +196,7 @@ class AdaLayerNormZeroSingle(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
@@ -225,7 +225,7 @@ class LuminaRMSNormZero(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None])
|
||||
@@ -257,10 +257,10 @@ class AdaLayerNormSingle(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# No modulation happening here.
|
||||
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
@@ -423,7 +423,7 @@ class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
|
||||
x: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
(
|
||||
shift_msa,
|
||||
@@ -463,7 +463,7 @@ class CogVideoXLayerNormZero(nn.Module):
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -401,7 +401,7 @@ class Conv1dBlock(nn.Module):
|
||||
self,
|
||||
inp_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
kernel_size: int | tuple[int, int],
|
||||
n_groups: int = 8,
|
||||
activation: str = "mish",
|
||||
):
|
||||
@@ -438,7 +438,7 @@ class ResidualTemporalBlock1D(nn.Module):
|
||||
inp_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
kernel_size: Union[int, Tuple[int, int]] = 5,
|
||||
kernel_size: int | tuple[int, int] = 5,
|
||||
activation: str = "mish",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -172,7 +172,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
@@ -241,8 +241,8 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = hidden_states
|
||||
residual_context = encoder_hidden_states
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
@@ -367,7 +367,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -376,7 +376,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -391,7 +391,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -462,9 +462,9 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -120,9 +120,9 @@ class CogVideoXBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
image_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
@@ -333,7 +333,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -342,7 +342,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -357,7 +357,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -427,13 +427,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: Union[int, float, torch.LongTensor],
|
||||
timestep: int | float | torch.LongTensor,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
ofs: Optional[Union[int, float, torch.LongTensor]] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
ofs: Optional[int | float | torch.LongTensor] = None,
|
||||
image_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -152,7 +152,7 @@ class LocalFacialExtractor(nn.Module):
|
||||
nn.Linear(vit_dim, vit_dim * num_id_token),
|
||||
)
|
||||
|
||||
def forward(self, id_embeds: torch.Tensor, vit_hidden_states: List[torch.Tensor]) -> torch.Tensor:
|
||||
def forward(self, id_embeds: torch.Tensor, vit_hidden_states: list[torch.Tensor]) -> torch.Tensor:
|
||||
# Repeat latent queries for the batch size
|
||||
latents = self.latents.repeat(id_embeds.size(0), 1, 1)
|
||||
|
||||
@@ -314,8 +314,8 @@ class ConsisIDBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
image_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
@@ -622,7 +622,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -631,7 +631,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -646,7 +646,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -684,14 +684,14 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: Union[int, float, torch.LongTensor],
|
||||
timestep: int | float | torch.LongTensor,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
image_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
id_cond: Optional[torch.Tensor] = None,
|
||||
id_vit_hidden: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -150,7 +150,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
cross_attention_kwargs: dict[str, Any] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -164,7 +164,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
cross_attention_kwargs ( `dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -352,7 +352,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -361,7 +361,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -376,7 +376,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -123,7 +123,7 @@ class LuminaNextDiTBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a forward pass through the LuminaNextDiTBlock.
|
||||
@@ -135,7 +135,7 @@ class LuminaNextDiTBlock(nn.Module):
|
||||
encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder.
|
||||
encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask.
|
||||
temb (`torch.Tensor`): Timestep embedding with text prompt embedding.
|
||||
cross_attention_kwargs (`Dict[str, Any]`): kwargs for cross attention.
|
||||
cross_attention_kwargs (`dict[str, Any]`): kwargs for cross attention.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
@@ -295,9 +295,9 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_mask: torch.Tensor,
|
||||
image_rotary_emb: torch.Tensor,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
cross_attention_kwargs: dict[str, Any] = None,
|
||||
return_dict=True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
|
||||
"""
|
||||
Forward pass of LuminaNextDiT.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -186,7 +186,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -195,7 +195,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -210,7 +210,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -289,8 +289,8 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
added_cond_kwargs: dict[str, torch.Tensor] = None,
|
||||
cross_attention_kwargs: dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
@@ -306,8 +306,8 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
self-attention.
|
||||
timestep (`torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
added_cond_kwargs: (`dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
|
||||
cross_attention_kwargs ( `dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -168,7 +168,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -177,7 +177,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -192,7 +192,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -245,7 +245,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
proj_embedding: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -416,7 +416,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -425,7 +425,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -440,7 +440,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -482,10 +482,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
controlnet_block_samples: Optional[tuple[torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -276,7 +276,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -285,7 +285,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -300,7 +300,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -351,7 +351,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
return_dict: bool = True,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> torch.FloatTensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`StableAudioDiTModel`] forward method.
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -201,7 +201,7 @@ class DecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias=None,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
) -> tuple[torch.Tensor]:
|
||||
hidden_states = self.layer[0](
|
||||
hidden_states,
|
||||
conditioning_emb=conditioning_emb,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user