Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples (#12544)
* Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples
* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -45,7 +45,7 @@ def check_size(image, height, width):
|
|||||||
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
|
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
|
||||||
|
|
||||||
|
|
||||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
|
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
|
||||||
inner_image = inner_image.convert("RGBA")
|
inner_image = inner_image.convert("RGBA")
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
|||||||
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
|
|||||||
center_input_sample: bool = False,
|
center_input_sample: bool = False,
|
||||||
flip_sin_to_cos: bool = True,
|
flip_sin_to_cos: bool = True,
|
||||||
freq_shift: int = 0,
|
freq_shift: int = 0,
|
||||||
down_block_types: Tuple[str] = (
|
down_block_types: Tuple[str, ...] = (
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"DownBlock2D",
|
"DownBlock2D",
|
||||||
),
|
),
|
||||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
up_block_types: Tuple[str, ...] = (
|
||||||
|
"UpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
),
|
||||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||||
downsample_padding: int = 1,
|
downsample_padding: int = 1,
|
||||||
mid_block_scale_factor: float = 1,
|
mid_block_scale_factor: float = 1,
|
||||||
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
|
|||||||
|
|
||||||
def _check_config(
|
def _check_config(
|
||||||
self,
|
self,
|
||||||
down_block_types: Tuple[str],
|
down_block_types: Tuple[str, ...],
|
||||||
up_block_types: Tuple[str],
|
up_block_types: Tuple[str, ...],
|
||||||
only_cross_attention: Union[bool, Tuple[bool]],
|
only_cross_attention: Union[bool, Tuple[bool]],
|
||||||
block_out_channels: Tuple[int],
|
block_out_channels: Tuple[int, ...],
|
||||||
layers_per_block: Union[int, Tuple[int]],
|
layers_per_block: Union[int, Tuple[int]],
|
||||||
cross_attention_dim: Union[int, Tuple[int]],
|
cross_attention_dim: Union[int, Tuple[int]],
|
||||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
||||||
|
|||||||
@@ -438,16 +438,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
|
|||||||
center_input_sample: bool = False,
|
center_input_sample: bool = False,
|
||||||
flip_sin_to_cos: bool = True,
|
flip_sin_to_cos: bool = True,
|
||||||
freq_shift: int = 0,
|
freq_shift: int = 0,
|
||||||
down_block_types: Tuple[str] = (
|
down_block_types: Tuple[str, ...] = (
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"DownBlock2D",
|
"DownBlock2D",
|
||||||
),
|
),
|
||||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
up_block_types: Tuple[str, ...] = (
|
||||||
|
"UpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
),
|
||||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||||
downsample_padding: int = 1,
|
downsample_padding: int = 1,
|
||||||
mid_block_scale_factor: float = 1,
|
mid_block_scale_factor: float = 1,
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ def get_block(
|
|||||||
attention_head_dim: int,
|
attention_head_dim: int,
|
||||||
norm_type: str,
|
norm_type: str,
|
||||||
act_fn: str,
|
act_fn: str,
|
||||||
qkv_mutliscales: Tuple[int] = (),
|
qkv_mutliscales: Tuple[int, ...] = (),
|
||||||
):
|
):
|
||||||
if block_type == "ResBlock":
|
if block_type == "ResBlock":
|
||||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||||
@@ -206,8 +206,8 @@ class Encoder(nn.Module):
|
|||||||
latent_channels: int,
|
latent_channels: int,
|
||||||
attention_head_dim: int = 32,
|
attention_head_dim: int = 32,
|
||||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
|
||||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||||
downsample_block_type: str = "pixel_unshuffle",
|
downsample_block_type: str = "pixel_unshuffle",
|
||||||
out_shortcut: bool = True,
|
out_shortcut: bool = True,
|
||||||
@@ -292,8 +292,8 @@ class Decoder(nn.Module):
|
|||||||
latent_channels: int,
|
latent_channels: int,
|
||||||
attention_head_dim: int = 32,
|
attention_head_dim: int = 32,
|
||||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
|
||||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||||
norm_type: Union[str, Tuple[str]] = "rms_norm",
|
norm_type: Union[str, Tuple[str]] = "rms_norm",
|
||||||
act_fn: Union[str, Tuple[str]] = "silu",
|
act_fn: Union[str, Tuple[str]] = "silu",
|
||||||
@@ -440,8 +440,8 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
|||||||
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||||
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
|
encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 3, 3, 3),
|
||||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
|
decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3, 3, 3),
|
||||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||||
upsample_block_type: str = "pixel_shuffle",
|
upsample_block_type: str = "pixel_shuffle",
|
||||||
|
|||||||
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
|||||||
self,
|
self,
|
||||||
in_channels: int = 3,
|
in_channels: int = 3,
|
||||||
out_channels: int = 3,
|
out_channels: int = 3,
|
||||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||||
block_out_channels: Tuple[int] = (64,),
|
block_out_channels: Tuple[int, ...] = (64,),
|
||||||
layers_per_block: int = 1,
|
layers_per_block: int = 1,
|
||||||
act_fn: str = "silu",
|
act_fn: str = "silu",
|
||||||
latent_channels: int = 4,
|
latent_channels: int = 4,
|
||||||
|
|||||||
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
|||||||
self,
|
self,
|
||||||
in_channels: int = 3,
|
in_channels: int = 3,
|
||||||
out_channels: int = 3,
|
out_channels: int = 3,
|
||||||
down_block_types: Tuple[str] = (
|
down_block_types: Tuple[str, ...] = (
|
||||||
"CogVideoXDownBlock3D",
|
"CogVideoXDownBlock3D",
|
||||||
"CogVideoXDownBlock3D",
|
"CogVideoXDownBlock3D",
|
||||||
"CogVideoXDownBlock3D",
|
"CogVideoXDownBlock3D",
|
||||||
"CogVideoXDownBlock3D",
|
"CogVideoXDownBlock3D",
|
||||||
),
|
),
|
||||||
up_block_types: Tuple[str] = (
|
up_block_types: Tuple[str, ...] = (
|
||||||
"CogVideoXUpBlock3D",
|
"CogVideoXUpBlock3D",
|
||||||
"CogVideoXUpBlock3D",
|
"CogVideoXUpBlock3D",
|
||||||
"CogVideoXUpBlock3D",
|
"CogVideoXUpBlock3D",
|
||||||
"CogVideoXUpBlock3D",
|
"CogVideoXUpBlock3D",
|
||||||
),
|
),
|
||||||
block_out_channels: Tuple[int] = (128, 256, 256, 512),
|
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
||||||
latent_channels: int = 16,
|
latent_channels: int = 16,
|
||||||
layers_per_block: int = 3,
|
layers_per_block: int = 3,
|
||||||
act_fn: str = "silu",
|
act_fn: str = "silu",
|
||||||
|
|||||||
@@ -653,7 +653,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
|||||||
"HunyuanVideoUpBlock3D",
|
"HunyuanVideoUpBlock3D",
|
||||||
"HunyuanVideoUpBlock3D",
|
"HunyuanVideoUpBlock3D",
|
||||||
),
|
),
|
||||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||||
layers_per_block: int = 2,
|
layers_per_block: int = 2,
|
||||||
act_fn: str = "silu",
|
act_fn: str = "silu",
|
||||||
norm_num_groups: int = 32,
|
norm_num_groups: int = 32,
|
||||||
|
|||||||
@@ -601,7 +601,7 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
|||||||
in_channels: int = 3,
|
in_channels: int = 3,
|
||||||
out_channels: int = 3,
|
out_channels: int = 3,
|
||||||
latent_channels: int = 32,
|
latent_channels: int = 32,
|
||||||
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
|
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
|
||||||
layers_per_block: int = 2,
|
layers_per_block: int = 2,
|
||||||
spatial_compression_ratio: int = 16,
|
spatial_compression_ratio: int = 16,
|
||||||
temporal_compression_ratio: int = 4,
|
temporal_compression_ratio: int = 4,
|
||||||
|
|||||||
@@ -688,8 +688,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
|||||||
self,
|
self,
|
||||||
in_channels: int = 15,
|
in_channels: int = 15,
|
||||||
out_channels: int = 3,
|
out_channels: int = 3,
|
||||||
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
|
encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),
|
||||||
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
|
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
|
||||||
latent_channels: int = 12,
|
latent_channels: int = 12,
|
||||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||||
act_fn: str = "silu",
|
act_fn: str = "silu",
|
||||||
|
|||||||
@@ -679,7 +679,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
|||||||
self,
|
self,
|
||||||
base_dim: int = 96,
|
base_dim: int = 96,
|
||||||
z_dim: int = 16,
|
z_dim: int = 16,
|
||||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
|
||||||
num_res_blocks: int = 2,
|
num_res_blocks: int = 2,
|
||||||
attn_scales: List[float] = [],
|
attn_scales: List[float] = [],
|
||||||
temperal_downsample: List[bool] = [False, True, True],
|
temperal_downsample: List[bool] = [False, True, True],
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
in_channels: int = 4,
|
in_channels: int = 4,
|
||||||
out_channels: int = 3,
|
out_channels: int = 3,
|
||||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||||
layers_per_block: int = 2,
|
layers_per_block: int = 2,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
|||||||
self,
|
self,
|
||||||
in_channels: int = 3,
|
in_channels: int = 3,
|
||||||
out_channels: int = 3,
|
out_channels: int = 3,
|
||||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||||
block_out_channels: Tuple[int] = (64,),
|
block_out_channels: Tuple[int, ...] = (64,),
|
||||||
layers_per_block: int = 1,
|
layers_per_block: int = 1,
|
||||||
latent_channels: int = 4,
|
latent_channels: int = 4,
|
||||||
sample_size: int = 32,
|
sample_size: int = 32,
|
||||||
|
|||||||
@@ -971,7 +971,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
|||||||
base_dim: int = 96,
|
base_dim: int = 96,
|
||||||
decoder_base_dim: Optional[int] = None,
|
decoder_base_dim: Optional[int] = None,
|
||||||
z_dim: int = 16,
|
z_dim: int = 16,
|
||||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
|
||||||
num_res_blocks: int = 2,
|
num_res_blocks: int = 2,
|
||||||
attn_scales: List[float] = [],
|
attn_scales: List[float] = [],
|
||||||
temperal_downsample: List[bool] = [False, True, True],
|
temperal_downsample: List[bool] = [False, True, True],
|
||||||
|
|||||||
@@ -293,14 +293,14 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
|||||||
self,
|
self,
|
||||||
conditioning_channels: int = 3,
|
conditioning_channels: int = 3,
|
||||||
conditioning_channel_order: str = "rgb",
|
conditioning_channel_order: str = "rgb",
|
||||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||||
time_embedding_mix: float = 1.0,
|
time_embedding_mix: float = 1.0,
|
||||||
learn_time_embedding: bool = False,
|
learn_time_embedding: bool = False,
|
||||||
num_attention_heads: Union[int, Tuple[int]] = 4,
|
num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||||
block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
|
||||||
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
base_block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
cross_attention_dim: int = 1024,
|
cross_attention_dim: int = 1024,
|
||||||
down_block_types: Tuple[str] = (
|
down_block_types: Tuple[str, ...] = (
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
@@ -436,7 +436,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
|||||||
time_embedding_mix: int = 1.0,
|
time_embedding_mix: int = 1.0,
|
||||||
conditioning_channels: int = 3,
|
conditioning_channels: int = 3,
|
||||||
conditioning_channel_order: str = "rgb",
|
conditioning_channel_order: str = "rgb",
|
||||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
|
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
|
||||||
@@ -529,14 +529,19 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
|||||||
self,
|
self,
|
||||||
# unet configs
|
# unet configs
|
||||||
sample_size: Optional[int] = 96,
|
sample_size: Optional[int] = 96,
|
||||||
down_block_types: Tuple[str] = (
|
down_block_types: Tuple[str, ...] = (
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"DownBlock2D",
|
"DownBlock2D",
|
||||||
),
|
),
|
||||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
up_block_types: Tuple[str, ...] = (
|
||||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
"UpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
),
|
||||||
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
norm_num_groups: Optional[int] = 32,
|
norm_num_groups: Optional[int] = 32,
|
||||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||||
@@ -550,10 +555,10 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
|||||||
# additional controlnet configs
|
# additional controlnet configs
|
||||||
time_embedding_mix: float = 1.0,
|
time_embedding_mix: float = 1.0,
|
||||||
ctrl_conditioning_channels: int = 3,
|
ctrl_conditioning_channels: int = 3,
|
||||||
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
ctrl_conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||||
ctrl_conditioning_channel_order: str = "rgb",
|
ctrl_conditioning_channel_order: str = "rgb",
|
||||||
ctrl_learn_time_embedding: bool = False,
|
ctrl_learn_time_embedding: bool = False,
|
||||||
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
ctrl_block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
|
||||||
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
|
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||||
ctrl_max_norm_num_groups: int = 32,
|
ctrl_max_norm_num_groups: int = 32,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -914,7 +914,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
|||||||
text_embed_dim: int = 4096,
|
text_embed_dim: int = 4096,
|
||||||
pooled_projection_dim: int = 768,
|
pooled_projection_dim: int = 768,
|
||||||
rope_theta: float = 256.0,
|
rope_theta: float = 256.0,
|
||||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
||||||
image_condition_type: Optional[str] = None,
|
image_condition_type: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ class HunyuanVideoFramepackTransformer3DModel(
|
|||||||
text_embed_dim: int = 4096,
|
text_embed_dim: int = 4096,
|
||||||
pooled_projection_dim: int = 768,
|
pooled_projection_dim: int = 768,
|
||||||
rope_theta: float = 256.0,
|
rope_theta: float = 256.0,
|
||||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
||||||
image_condition_type: Optional[str] = None,
|
image_condition_type: Optional[str] = None,
|
||||||
has_image_proj: int = False,
|
has_image_proj: int = False,
|
||||||
image_proj_dim: int = 1152,
|
image_proj_dim: int = 1152,
|
||||||
|
|||||||
@@ -689,7 +689,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
|||||||
text_embed_dim: int = 3584,
|
text_embed_dim: int = 3584,
|
||||||
text_embed_2_dim: Optional[int] = None,
|
text_embed_2_dim: Optional[int] = None,
|
||||||
rope_theta: float = 256.0,
|
rope_theta: float = 256.0,
|
||||||
rope_axes_dim: Tuple[int] = (64, 64),
|
rope_axes_dim: Tuple[int, ...] = (64, 64),
|
||||||
use_meanflow: bool = False,
|
use_meanflow: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -570,7 +570,7 @@ class SkyReelsV2Transformer3DModel(
|
|||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
patch_size: Tuple[int] = (1, 2, 2),
|
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||||
num_attention_heads: int = 16,
|
num_attention_heads: int = 16,
|
||||||
attention_head_dim: int = 128,
|
attention_head_dim: int = 128,
|
||||||
in_channels: int = 16,
|
in_channels: int = 16,
|
||||||
|
|||||||
@@ -563,7 +563,7 @@ class WanTransformer3DModel(
|
|||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
patch_size: Tuple[int] = (1, 2, 2),
|
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||||
num_attention_heads: int = 40,
|
num_attention_heads: int = 40,
|
||||||
attention_head_dim: int = 128,
|
attention_head_dim: int = 128,
|
||||||
in_channels: int = 16,
|
in_channels: int = 16,
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class WanVACETransformer3DModel(
|
|||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
patch_size: Tuple[int] = (1, 2, 2),
|
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||||
num_attention_heads: int = 40,
|
num_attention_heads: int = 40,
|
||||||
attention_head_dim: int = 128,
|
attention_head_dim: int = 128,
|
||||||
in_channels: int = 16,
|
in_channels: int = 16,
|
||||||
|
|||||||
@@ -86,11 +86,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
|||||||
flip_sin_to_cos: bool = True,
|
flip_sin_to_cos: bool = True,
|
||||||
use_timestep_embedding: bool = False,
|
use_timestep_embedding: bool = False,
|
||||||
freq_shift: float = 0.0,
|
freq_shift: float = 0.0,
|
||||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
mid_block_type: str = "UNetMidBlock1D",
|
||||||
out_block_type: str = None,
|
out_block_type: str = None,
|
||||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
block_out_channels: Tuple[int, ...] = (32, 32, 64),
|
||||||
act_fn: str = None,
|
act_fn: str = None,
|
||||||
norm_num_groups: int = 8,
|
norm_num_groups: int = 8,
|
||||||
layers_per_block: int = 1,
|
layers_per_block: int = 1,
|
||||||
|
|||||||
@@ -177,16 +177,21 @@ class UNet2DConditionModel(
|
|||||||
center_input_sample: bool = False,
|
center_input_sample: bool = False,
|
||||||
flip_sin_to_cos: bool = True,
|
flip_sin_to_cos: bool = True,
|
||||||
freq_shift: int = 0,
|
freq_shift: int = 0,
|
||||||
down_block_types: Tuple[str] = (
|
down_block_types: Tuple[str, ...] = (
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"DownBlock2D",
|
"DownBlock2D",
|
||||||
),
|
),
|
||||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
up_block_types: Tuple[str, ...] = (
|
||||||
|
"UpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
),
|
||||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||||
downsample_padding: int = 1,
|
downsample_padding: int = 1,
|
||||||
mid_block_scale_factor: float = 1,
|
mid_block_scale_factor: float = 1,
|
||||||
@@ -486,10 +491,10 @@ class UNet2DConditionModel(
|
|||||||
|
|
||||||
def _check_config(
|
def _check_config(
|
||||||
self,
|
self,
|
||||||
down_block_types: Tuple[str],
|
down_block_types: Tuple[str, ...],
|
||||||
up_block_types: Tuple[str],
|
up_block_types: Tuple[str, ...],
|
||||||
only_cross_attention: Union[bool, Tuple[bool]],
|
only_cross_attention: Union[bool, Tuple[bool]],
|
||||||
block_out_channels: Tuple[int],
|
block_out_channels: Tuple[int, ...],
|
||||||
layers_per_block: Union[int, Tuple[int]],
|
layers_per_block: Union[int, Tuple[int]],
|
||||||
cross_attention_dim: Union[int, Tuple[int]],
|
cross_attention_dim: Union[int, Tuple[int]],
|
||||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
|||||||
groups: int = 32,
|
groups: int = 32,
|
||||||
attention_head_dim: int = 64,
|
attention_head_dim: int = 64,
|
||||||
layers_per_block: Union[int, Tuple[int]] = 3,
|
layers_per_block: Union[int, Tuple[int]] = 3,
|
||||||
block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
|
block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
|
||||||
cross_attention_dim: Union[int, Tuple[int]] = 4096,
|
cross_attention_dim: Union[int, Tuple[int]] = 4096,
|
||||||
encoder_hid_dim: int = 4096,
|
encoder_hid_dim: int = 4096,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
|||||||
sample_size: Optional[int] = None,
|
sample_size: Optional[int] = None,
|
||||||
in_channels: int = 8,
|
in_channels: int = 8,
|
||||||
out_channels: int = 4,
|
out_channels: int = 4,
|
||||||
down_block_types: Tuple[str] = (
|
down_block_types: Tuple[str, ...] = (
|
||||||
"CrossAttnDownBlockSpatioTemporal",
|
"CrossAttnDownBlockSpatioTemporal",
|
||||||
"CrossAttnDownBlockSpatioTemporal",
|
"CrossAttnDownBlockSpatioTemporal",
|
||||||
"CrossAttnDownBlockSpatioTemporal",
|
"CrossAttnDownBlockSpatioTemporal",
|
||||||
"DownBlockSpatioTemporal",
|
"DownBlockSpatioTemporal",
|
||||||
),
|
),
|
||||||
up_block_types: Tuple[str] = (
|
up_block_types: Tuple[str, ...] = (
|
||||||
"UpBlockSpatioTemporal",
|
"UpBlockSpatioTemporal",
|
||||||
"CrossAttnUpBlockSpatioTemporal",
|
"CrossAttnUpBlockSpatioTemporal",
|
||||||
"CrossAttnUpBlockSpatioTemporal",
|
"CrossAttnUpBlockSpatioTemporal",
|
||||||
"CrossAttnUpBlockSpatioTemporal",
|
"CrossAttnUpBlockSpatioTemporal",
|
||||||
),
|
),
|
||||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
addition_time_embed_dim: int = 256,
|
addition_time_embed_dim: int = 256,
|
||||||
projection_class_embeddings_input_dim: int = 768,
|
projection_class_embeddings_input_dim: int = 768,
|
||||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
||||||
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
|
num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
|
||||||
num_frames: int = 25,
|
num_frames: int = 25,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
timestep_ratio_embedding_dim: int = 64,
|
timestep_ratio_embedding_dim: int = 64,
|
||||||
patch_size: int = 1,
|
patch_size: int = 1,
|
||||||
conditioning_dim: int = 2048,
|
conditioning_dim: int = 2048,
|
||||||
block_out_channels: Tuple[int] = (2048, 2048),
|
block_out_channels: Tuple[int, ...] = (2048, 2048),
|
||||||
num_attention_heads: Tuple[int] = (32, 32),
|
num_attention_heads: Tuple[int, ...] = (32, 32),
|
||||||
down_num_layers_per_block: Tuple[int] = (8, 24),
|
down_num_layers_per_block: Tuple[int, ...] = (8, 24),
|
||||||
up_num_layers_per_block: Tuple[int] = (24, 8),
|
up_num_layers_per_block: Tuple[int, ...] = (24, 8),
|
||||||
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
|
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
|
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
|
||||||
self_attn: Union[bool, Tuple[bool]] = True,
|
self_attn: Union[bool, Tuple[bool]] = True,
|
||||||
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
|
timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
|
||||||
switch_level: Optional[Tuple[bool]] = None,
|
switch_level: Optional[Tuple[bool]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module):
|
|||||||
|
|
||||||
in_channels: int = 3
|
in_channels: int = 3
|
||||||
out_channels: int = 3
|
out_channels: int = 3
|
||||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
|
||||||
block_out_channels: Tuple[int] = (64,)
|
block_out_channels: Tuple[int, ...] = (64,)
|
||||||
layers_per_block: int = 2
|
layers_per_block: int = 2
|
||||||
norm_num_groups: int = 32
|
norm_num_groups: int = 32
|
||||||
act_fn: str = "silu"
|
act_fn: str = "silu"
|
||||||
@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module):
|
|||||||
|
|
||||||
in_channels: int = 3
|
in_channels: int = 3
|
||||||
out_channels: int = 3
|
out_channels: int = 3
|
||||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
|
||||||
block_out_channels: int = (64,)
|
block_out_channels: Tuple[int, ...] = (64,)
|
||||||
layers_per_block: int = 2
|
layers_per_block: int = 2
|
||||||
norm_num_groups: int = 32
|
norm_num_groups: int = 32
|
||||||
act_fn: str = "silu"
|
act_fn: str = "silu"
|
||||||
@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
in_channels: int = 3
|
in_channels: int = 3
|
||||||
out_channels: int = 3
|
out_channels: int = 3
|
||||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
|
||||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
|
||||||
block_out_channels: Tuple[int] = (64,)
|
block_out_channels: Tuple[int, ...] = (64,)
|
||||||
layers_per_block: int = 1
|
layers_per_block: int = 1
|
||||||
act_fn: str = "silu"
|
act_fn: str = "silu"
|
||||||
latent_channels: int = 4
|
latent_channels: int = 4
|
||||||
|
|||||||
@@ -245,16 +245,21 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
|||||||
out_channels: int = 4,
|
out_channels: int = 4,
|
||||||
flip_sin_to_cos: bool = True,
|
flip_sin_to_cos: bool = True,
|
||||||
freq_shift: int = 0,
|
freq_shift: int = 0,
|
||||||
down_block_types: Tuple[str] = (
|
down_block_types: Tuple[str, ...] = (
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"CrossAttnDownBlock2D",
|
"CrossAttnDownBlock2D",
|
||||||
"DownBlock2D",
|
"DownBlock2D",
|
||||||
),
|
),
|
||||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
up_block_types: Tuple[str, ...] = (
|
||||||
|
"UpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
),
|
||||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||||
downsample_padding: int = 1,
|
downsample_padding: int = 1,
|
||||||
mid_block_scale_factor: float = 1,
|
mid_block_scale_factor: float = 1,
|
||||||
|
|||||||
@@ -374,21 +374,21 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|||||||
center_input_sample: bool = False,
|
center_input_sample: bool = False,
|
||||||
flip_sin_to_cos: bool = True,
|
flip_sin_to_cos: bool = True,
|
||||||
freq_shift: int = 0,
|
freq_shift: int = 0,
|
||||||
down_block_types: Tuple[str] = (
|
down_block_types: Tuple[str, ...] = (
|
||||||
"CrossAttnDownBlockFlat",
|
"CrossAttnDownBlockFlat",
|
||||||
"CrossAttnDownBlockFlat",
|
"CrossAttnDownBlockFlat",
|
||||||
"CrossAttnDownBlockFlat",
|
"CrossAttnDownBlockFlat",
|
||||||
"DownBlockFlat",
|
"DownBlockFlat",
|
||||||
),
|
),
|
||||||
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
|
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
|
||||||
up_block_types: Tuple[str] = (
|
up_block_types: Tuple[str, ...] = (
|
||||||
"UpBlockFlat",
|
"UpBlockFlat",
|
||||||
"CrossAttnUpBlockFlat",
|
"CrossAttnUpBlockFlat",
|
||||||
"CrossAttnUpBlockFlat",
|
"CrossAttnUpBlockFlat",
|
||||||
"CrossAttnUpBlockFlat",
|
"CrossAttnUpBlockFlat",
|
||||||
),
|
),
|
||||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||||
downsample_padding: int = 1,
|
downsample_padding: int = 1,
|
||||||
mid_block_scale_factor: float = 1,
|
mid_block_scale_factor: float = 1,
|
||||||
|
|||||||
@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
param_names: Tuple[str] = (
|
param_names: Tuple[str, ...] = (
|
||||||
"nerstf.mlp.0.weight",
|
"nerstf.mlp.0.weight",
|
||||||
"nerstf.mlp.1.weight",
|
"nerstf.mlp.1.weight",
|
||||||
"nerstf.mlp.2.weight",
|
"nerstf.mlp.2.weight",
|
||||||
@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
param_names: Tuple[str] = (
|
param_names: Tuple[str, ...] = (
|
||||||
"nerstf.mlp.0.weight",
|
"nerstf.mlp.0.weight",
|
||||||
"nerstf.mlp.1.weight",
|
"nerstf.mlp.1.weight",
|
||||||
"nerstf.mlp.2.weight",
|
"nerstf.mlp.2.weight",
|
||||||
"nerstf.mlp.3.weight",
|
"nerstf.mlp.3.weight",
|
||||||
),
|
),
|
||||||
param_shapes: Tuple[Tuple[int]] = (
|
param_shapes: Tuple[Tuple[int, int], ...] = (
|
||||||
(256, 93),
|
(256, 93),
|
||||||
(256, 256),
|
(256, 256),
|
||||||
(256, 256),
|
(256, 256),
|
||||||
@@ -804,7 +804,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
|||||||
n_hidden_layers: int = 6,
|
n_hidden_layers: int = 6,
|
||||||
act_fn: str = "swish",
|
act_fn: str = "swish",
|
||||||
insert_direction_at: int = 4,
|
insert_direction_at: int = 4,
|
||||||
background: Tuple[float] = (
|
background: Tuple[float, ...] = (
|
||||||
255.0,
|
255.0,
|
||||||
255.0,
|
255.0,
|
||||||
255.0,
|
255.0,
|
||||||
|
|||||||
Reference in New Issue
Block a user