Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples (#12544)

* Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Cesaryuan
2025-11-11 05:57:52 +08:00
committed by GitHub
parent 8f6328c4a4
commit 5a47442f92
28 changed files with 110 additions and 85 deletions
+1 -1
View File
@@ -45,7 +45,7 @@ def check_size(image, height, width):
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}") raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)): def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
inner_image = inner_image.convert("RGBA") inner_image = inner_image.convert("RGBA")
image = image.convert("RGB") image = image.convert("RGB")
+11 -6
View File
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
center_input_sample: bool = False, center_input_sample: bool = False,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"DownBlock2D", "DownBlock2D",
), ),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False, only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
def _check_config( def _check_config(
self, self,
down_block_types: Tuple[str], down_block_types: Tuple[str, ...],
up_block_types: Tuple[str], up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]], only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int], block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]], layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]], cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
@@ -438,16 +438,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
center_input_sample: bool = False, center_input_sample: bool = False,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"DownBlock2D", "DownBlock2D",
), ),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False, only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
@@ -102,7 +102,7 @@ def get_block(
attention_head_dim: int, attention_head_dim: int,
norm_type: str, norm_type: str,
act_fn: str, act_fn: str,
qkv_mutliscales: Tuple[int] = (), qkv_mutliscales: Tuple[int, ...] = (),
): ):
if block_type == "ResBlock": if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn) block = ResBlock(in_channels, out_channels, norm_type, act_fn)
@@ -206,8 +206,8 @@ class Encoder(nn.Module):
latent_channels: int, latent_channels: int,
attention_head_dim: int = 32, attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock", block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024), block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2), layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)), qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle", downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True, out_shortcut: bool = True,
@@ -292,8 +292,8 @@ class Decoder(nn.Module):
latent_channels: int, latent_channels: int,
attention_head_dim: int = 32, attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock", block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024), block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2), layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)), qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
norm_type: Union[str, Tuple[str]] = "rms_norm", norm_type: Union[str, Tuple[str]] = "rms_norm",
act_fn: Union[str, Tuple[str]] = "silu", act_fn: Union[str, Tuple[str]] = "silu",
@@ -440,8 +440,8 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
decoder_block_types: Union[str, Tuple[str]] = "ResBlock", decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024), encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024), decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3), encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3), decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)), encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)), decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
upsample_block_type: str = "pixel_shuffle", upsample_block_type: str = "pixel_shuffle",
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
self, self,
in_channels: int = 3, in_channels: int = 3,
out_channels: int = 3, out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",), down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",), up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,), block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1, layers_per_block: int = 1,
act_fn: str = "silu", act_fn: str = "silu",
latent_channels: int = 4, latent_channels: int = 4,
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self, self,
in_channels: int = 3, in_channels: int = 3,
out_channels: int = 3, out_channels: int = 3,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CogVideoXDownBlock3D", "CogVideoXDownBlock3D",
"CogVideoXDownBlock3D", "CogVideoXDownBlock3D",
"CogVideoXDownBlock3D", "CogVideoXDownBlock3D",
"CogVideoXDownBlock3D", "CogVideoXDownBlock3D",
), ),
up_block_types: Tuple[str] = ( up_block_types: Tuple[str, ...] = (
"CogVideoXUpBlock3D", "CogVideoXUpBlock3D",
"CogVideoXUpBlock3D", "CogVideoXUpBlock3D",
"CogVideoXUpBlock3D", "CogVideoXUpBlock3D",
"CogVideoXUpBlock3D", "CogVideoXUpBlock3D",
), ),
block_out_channels: Tuple[int] = (128, 256, 256, 512), block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
latent_channels: int = 16, latent_channels: int = 16,
layers_per_block: int = 3, layers_per_block: int = 3,
act_fn: str = "silu", act_fn: str = "silu",
@@ -653,7 +653,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
"HunyuanVideoUpBlock3D", "HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D", "HunyuanVideoUpBlock3D",
), ),
block_out_channels: Tuple[int] = (128, 256, 512, 512), block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2, layers_per_block: int = 2,
act_fn: str = "silu", act_fn: str = "silu",
norm_num_groups: int = 32, norm_num_groups: int = 32,
@@ -601,7 +601,7 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
in_channels: int = 3, in_channels: int = 3,
out_channels: int = 3, out_channels: int = 3,
latent_channels: int = 32, latent_channels: int = 32,
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024), block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2, layers_per_block: int = 2,
spatial_compression_ratio: int = 16, spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4, temporal_compression_ratio: int = 4,
@@ -688,8 +688,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
self, self,
in_channels: int = 15, in_channels: int = 15,
out_channels: int = 3, out_channels: int = 3,
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384), encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768), decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
latent_channels: int = 12, latent_channels: int = 12,
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
act_fn: str = "silu", act_fn: str = "silu",
@@ -679,7 +679,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self, self,
base_dim: int = 96, base_dim: int = 96,
z_dim: int = 16, z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4], dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
num_res_blocks: int = 2, num_res_blocks: int = 2,
attn_scales: List[float] = [], attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True], temperal_downsample: List[bool] = [False, True, True],
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
self, self,
in_channels: int = 4, in_channels: int = 4,
out_channels: int = 3, out_channels: int = 3,
block_out_channels: Tuple[int] = (128, 256, 512, 512), block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2, layers_per_block: int = 2,
): ):
super().__init__() super().__init__()
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
self, self,
in_channels: int = 3, in_channels: int = 3,
out_channels: int = 3, out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",), down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,), block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1, layers_per_block: int = 1,
latent_channels: int = 4, latent_channels: int = 4,
sample_size: int = 32, sample_size: int = 32,
@@ -971,7 +971,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
base_dim: int = 96, base_dim: int = 96,
decoder_base_dim: Optional[int] = None, decoder_base_dim: Optional[int] = None,
z_dim: int = 16, z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4], dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
num_res_blocks: int = 2, num_res_blocks: int = 2,
attn_scales: List[float] = [], attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True], temperal_downsample: List[bool] = [False, True, True],
@@ -293,14 +293,14 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
self, self,
conditioning_channels: int = 3, conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb", conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
time_embedding_mix: float = 1.0, time_embedding_mix: float = 1.0,
learn_time_embedding: bool = False, learn_time_embedding: bool = False,
num_attention_heads: Union[int, Tuple[int]] = 4, num_attention_heads: Union[int, Tuple[int]] = 4,
block_out_channels: Tuple[int] = (4, 8, 16, 16), block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), base_block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
cross_attention_dim: int = 1024, cross_attention_dim: int = 1024,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
@@ -436,7 +436,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
time_embedding_mix: int = 1.0, time_embedding_mix: int = 1.0,
conditioning_channels: int = 3, conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb", conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
): ):
r""" r"""
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`]. Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
@@ -529,14 +529,19 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
self, self,
# unet configs # unet configs
sample_size: Optional[int] = 96, sample_size: Optional[int] = 96,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"DownBlock2D", "DownBlock2D",
), ),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), up_block_types: Tuple[str, ...] = (
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), "UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
norm_num_groups: Optional[int] = 32, norm_num_groups: Optional[int] = 32,
cross_attention_dim: Union[int, Tuple[int]] = 1024, cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int]] = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -550,10 +555,10 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
# additional controlnet configs # additional controlnet configs
time_embedding_mix: float = 1.0, time_embedding_mix: float = 1.0,
ctrl_conditioning_channels: int = 3, ctrl_conditioning_channels: int = 3,
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), ctrl_conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
ctrl_conditioning_channel_order: str = "rgb", ctrl_conditioning_channel_order: str = "rgb",
ctrl_learn_time_embedding: bool = False, ctrl_learn_time_embedding: bool = False,
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), ctrl_block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4, ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
ctrl_max_norm_num_groups: int = 32, ctrl_max_norm_num_groups: int = 32,
): ):
@@ -914,7 +914,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
text_embed_dim: int = 4096, text_embed_dim: int = 4096,
pooled_projection_dim: int = 768, pooled_projection_dim: int = 768,
rope_theta: float = 256.0, rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56), rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
image_condition_type: Optional[str] = None, image_condition_type: Optional[str] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
@@ -139,7 +139,7 @@ class HunyuanVideoFramepackTransformer3DModel(
text_embed_dim: int = 4096, text_embed_dim: int = 4096,
pooled_projection_dim: int = 768, pooled_projection_dim: int = 768,
rope_theta: float = 256.0, rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56), rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
image_condition_type: Optional[str] = None, image_condition_type: Optional[str] = None,
has_image_proj: int = False, has_image_proj: int = False,
image_proj_dim: int = 1152, image_proj_dim: int = 1152,
@@ -689,7 +689,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
text_embed_dim: int = 3584, text_embed_dim: int = 3584,
text_embed_2_dim: Optional[int] = None, text_embed_2_dim: Optional[int] = None,
rope_theta: float = 256.0, rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (64, 64), rope_axes_dim: Tuple[int, ...] = (64, 64),
use_meanflow: bool = False, use_meanflow: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
@@ -570,7 +570,7 @@ class SkyReelsV2Transformer3DModel(
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
patch_size: Tuple[int] = (1, 2, 2), patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 16, num_attention_heads: int = 16,
attention_head_dim: int = 128, attention_head_dim: int = 128,
in_channels: int = 16, in_channels: int = 16,
@@ -563,7 +563,7 @@ class WanTransformer3DModel(
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
patch_size: Tuple[int] = (1, 2, 2), patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 40, num_attention_heads: int = 40,
attention_head_dim: int = 128, attention_head_dim: int = 128,
in_channels: int = 16, in_channels: int = 16,
@@ -182,7 +182,7 @@ class WanVACETransformer3DModel(
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
patch_size: Tuple[int] = (1, 2, 2), patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 40, num_attention_heads: int = 40,
attention_head_dim: int = 128, attention_head_dim: int = 128,
in_channels: int = 16, in_channels: int = 16,
+4 -4
View File
@@ -86,11 +86,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False, use_timestep_embedding: bool = False,
freq_shift: float = 0.0, freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D", mid_block_type: str = "UNetMidBlock1D",
out_block_type: str = None, out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64), block_out_channels: Tuple[int, ...] = (32, 32, 64),
act_fn: str = None, act_fn: str = None,
norm_num_groups: int = 8, norm_num_groups: int = 8,
layers_per_block: int = 1, layers_per_block: int = 1,
@@ -177,16 +177,21 @@ class UNet2DConditionModel(
center_input_sample: bool = False, center_input_sample: bool = False,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"DownBlock2D", "DownBlock2D",
), ),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False, only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
@@ -486,10 +491,10 @@ class UNet2DConditionModel(
def _check_config( def _check_config(
self, self,
down_block_types: Tuple[str], down_block_types: Tuple[str, ...],
up_block_types: Tuple[str], up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]], only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int], block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]], layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]], cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
groups: int = 32, groups: int = 32,
attention_head_dim: int = 64, attention_head_dim: int = 64,
layers_per_block: Union[int, Tuple[int]] = 3, layers_per_block: Union[int, Tuple[int]] = 3,
block_out_channels: Tuple[int] = (384, 768, 1536, 3072), block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
cross_attention_dim: Union[int, Tuple[int]] = 4096, cross_attention_dim: Union[int, Tuple[int]] = 4096,
encoder_hid_dim: int = 4096, encoder_hid_dim: int = 4096,
): ):
@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
sample_size: Optional[int] = None, sample_size: Optional[int] = None,
in_channels: int = 8, in_channels: int = 8,
out_channels: int = 4, out_channels: int = 4,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal", "DownBlockSpatioTemporal",
), ),
up_block_types: Tuple[str] = ( up_block_types: Tuple[str, ...] = (
"UpBlockSpatioTemporal", "UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal",
), ),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
addition_time_embed_dim: int = 256, addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768, projection_class_embeddings_input_dim: int = 768,
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
cross_attention_dim: Union[int, Tuple[int]] = 1024, cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20), num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
num_frames: int = 25, num_frames: int = 25,
): ):
super().__init__() super().__init__()
@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
timestep_ratio_embedding_dim: int = 64, timestep_ratio_embedding_dim: int = 64,
patch_size: int = 1, patch_size: int = 1,
conditioning_dim: int = 2048, conditioning_dim: int = 2048,
block_out_channels: Tuple[int] = (2048, 2048), block_out_channels: Tuple[int, ...] = (2048, 2048),
num_attention_heads: Tuple[int] = (32, 32), num_attention_heads: Tuple[int, ...] = (32, 32),
down_num_layers_per_block: Tuple[int] = (8, 24), down_num_layers_per_block: Tuple[int, ...] = (8, 24),
up_num_layers_per_block: Tuple[int] = (24, 8), up_num_layers_per_block: Tuple[int, ...] = (24, 8),
down_blocks_repeat_mappers: Optional[Tuple[int]] = ( down_blocks_repeat_mappers: Optional[Tuple[int]] = (
1, 1,
1, 1,
@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
kernel_size=3, kernel_size=3,
dropout: Union[float, Tuple[float]] = (0.1, 0.1), dropout: Union[float, Tuple[float]] = (0.1, 0.1),
self_attn: Union[bool, Tuple[bool]] = True, self_attn: Union[bool, Tuple[bool]] = True,
timestep_conditioning_type: Tuple[str] = ("sca", "crp"), timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
switch_level: Optional[Tuple[bool]] = None, switch_level: Optional[Tuple[bool]] = None,
): ):
""" """
+7 -7
View File
@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module):
in_channels: int = 3 in_channels: int = 3
out_channels: int = 3 out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",) down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
block_out_channels: Tuple[int] = (64,) block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2 layers_per_block: int = 2
norm_num_groups: int = 32 norm_num_groups: int = 32
act_fn: str = "silu" act_fn: str = "silu"
@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module):
in_channels: int = 3 in_channels: int = 3
out_channels: int = 3 out_channels: int = 3
up_block_types: Tuple[str] = ("UpDecoderBlock2D",) up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: int = (64,) block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2 layers_per_block: int = 2
norm_num_groups: int = 32 norm_num_groups: int = 32
act_fn: str = "silu" act_fn: str = "silu"
@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
in_channels: int = 3 in_channels: int = 3
out_channels: int = 3 out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",) down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
up_block_types: Tuple[str] = ("UpDecoderBlock2D",) up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int] = (64,) block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 1 layers_per_block: int = 1
act_fn: str = "silu" act_fn: str = "silu"
latent_channels: int = 4 latent_channels: int = 4
@@ -245,16 +245,21 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
out_channels: int = 4, out_channels: int = 4,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"DownBlock2D", "DownBlock2D",
), ),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False, only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
@@ -374,21 +374,21 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
center_input_sample: bool = False, center_input_sample: bool = False,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat",
"DownBlockFlat", "DownBlockFlat",
), ),
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn", mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
up_block_types: Tuple[str] = ( up_block_types: Tuple[str, ...] = (
"UpBlockFlat", "UpBlockFlat",
"CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",
), ),
only_cross_attention: Union[bool, Tuple[bool]] = False, only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
+4 -4
View File
@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
def __init__( def __init__(
self, self,
*, *,
param_names: Tuple[str] = ( param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight", "nerstf.mlp.0.weight",
"nerstf.mlp.1.weight", "nerstf.mlp.1.weight",
"nerstf.mlp.2.weight", "nerstf.mlp.2.weight",
@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin):
def __init__( def __init__(
self, self,
*, *,
param_names: Tuple[str] = ( param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight", "nerstf.mlp.0.weight",
"nerstf.mlp.1.weight", "nerstf.mlp.1.weight",
"nerstf.mlp.2.weight", "nerstf.mlp.2.weight",
"nerstf.mlp.3.weight", "nerstf.mlp.3.weight",
), ),
param_shapes: Tuple[Tuple[int]] = ( param_shapes: Tuple[Tuple[int, int], ...] = (
(256, 93), (256, 93),
(256, 256), (256, 256),
(256, 256), (256, 256),
@@ -804,7 +804,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
n_hidden_layers: int = 6, n_hidden_layers: int = 6,
act_fn: str = "swish", act_fn: str = "swish",
insert_direction_at: int = 4, insert_direction_at: int = 4,
background: Tuple[float] = ( background: Tuple[float, ...] = (
255.0, 255.0,
255.0, 255.0,
255.0, 255.0,