Flax: Trickle down norm_num_groups (#789)
* pass norm_num_groups param and add tests * set resnet_groups for FlaxUNetMidBlock2D * fixed docstrings * fixed typo * using is_flax_available util and created require_flax decorator
This commit is contained in:
@@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module):
|
|||||||
Output channels
|
Output channels
|
||||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||||
Dropout rate
|
Dropout rate
|
||||||
|
groups (:obj:`int`, *optional*, defaults to `32`):
|
||||||
|
The number of groups to use for group norm.
|
||||||
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
|
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
|
||||||
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
|
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
|
||||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||||
@@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module):
|
|||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int = None
|
out_channels: int = None
|
||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
|
groups: int = 32
|
||||||
use_nin_shortcut: bool = None
|
use_nin_shortcut: bool = None
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
out_channels = self.in_channels if self.out_channels is None else self.out_channels
|
out_channels = self.in_channels if self.out_channels is None else self.out_channels
|
||||||
|
|
||||||
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
|
||||||
self.conv1 = nn.Conv(
|
self.conv1 = nn.Conv(
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=(3, 3),
|
kernel_size=(3, 3),
|
||||||
@@ -143,7 +146,7 @@ class FlaxResnetBlock2D(nn.Module):
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
|
||||||
self.dropout_layer = nn.Dropout(self.dropout)
|
self.dropout_layer = nn.Dropout(self.dropout)
|
||||||
self.conv2 = nn.Conv(
|
self.conv2 = nn.Conv(
|
||||||
out_channels,
|
out_channels,
|
||||||
@@ -191,12 +194,15 @@ class FlaxAttentionBlock(nn.Module):
|
|||||||
Input channels
|
Input channels
|
||||||
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
|
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
|
||||||
Number of attention heads
|
Number of attention heads
|
||||||
|
num_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||||
|
The number of groups to use for group norm
|
||||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||||
Parameters `dtype`
|
Parameters `dtype`
|
||||||
|
|
||||||
"""
|
"""
|
||||||
channels: int
|
channels: int
|
||||||
num_head_channels: int = None
|
num_head_channels: int = None
|
||||||
|
num_groups: int = 32
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
@@ -204,7 +210,7 @@ class FlaxAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
|
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
|
||||||
|
|
||||||
self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
|
||||||
self.query, self.key, self.value = dense(), dense(), dense()
|
self.query, self.key, self.value = dense(), dense(), dense()
|
||||||
self.proj_attn = dense()
|
self.proj_attn = dense()
|
||||||
|
|
||||||
@@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
|
|||||||
Dropout rate
|
Dropout rate
|
||||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||||
Number of Resnet layer block
|
Number of Resnet layer block
|
||||||
|
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||||
|
The number of groups to use for the Resnet block group norm
|
||||||
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
||||||
Whether to add downsample layer
|
Whether to add downsample layer
|
||||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||||
@@ -273,6 +281,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
|
|||||||
out_channels: int
|
out_channels: int
|
||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
num_layers: int = 1
|
num_layers: int = 1
|
||||||
|
resnet_groups: int = 32
|
||||||
add_downsample: bool = True
|
add_downsample: bool = True
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
@@ -285,6 +294,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
|
|||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=self.out_channels,
|
out_channels=self.out_channels,
|
||||||
dropout=self.dropout,
|
dropout=self.dropout,
|
||||||
|
groups=self.resnet_groups,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
resnets.append(res_block)
|
resnets.append(res_block)
|
||||||
@@ -303,9 +313,9 @@ class FlaxDownEncoderBlock2D(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxUpEncoderBlock2D(nn.Module):
|
class FlaxUpDecoderBlock2D(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
|
Flax Resnet blocks-based Decoder block for diffusion-based VAE.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
in_channels (:obj:`int`):
|
in_channels (:obj:`int`):
|
||||||
@@ -316,8 +326,10 @@ class FlaxUpEncoderBlock2D(nn.Module):
|
|||||||
Dropout rate
|
Dropout rate
|
||||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||||
Number of Resnet layer block
|
Number of Resnet layer block
|
||||||
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||||
Whether to add downsample layer
|
The number of groups to use for the Resnet block group norm
|
||||||
|
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to add upsample layer
|
||||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||||
Parameters `dtype`
|
Parameters `dtype`
|
||||||
"""
|
"""
|
||||||
@@ -325,6 +337,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
|
|||||||
out_channels: int
|
out_channels: int
|
||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
num_layers: int = 1
|
num_layers: int = 1
|
||||||
|
resnet_groups: int = 32
|
||||||
add_upsample: bool = True
|
add_upsample: bool = True
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
@@ -336,6 +349,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
|
|||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=self.out_channels,
|
out_channels=self.out_channels,
|
||||||
dropout=self.dropout,
|
dropout=self.dropout,
|
||||||
|
groups=self.resnet_groups,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
resnets.append(res_block)
|
resnets.append(res_block)
|
||||||
@@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module):
|
|||||||
Dropout rate
|
Dropout rate
|
||||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||||
Number of Resnet layer block
|
Number of Resnet layer block
|
||||||
|
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||||
|
The number of groups to use for the Resnet and Attention block group norm
|
||||||
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
|
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
|
||||||
Number of attention heads for each attention block
|
Number of attention heads for each attention block
|
||||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||||
@@ -374,16 +390,20 @@ class FlaxUNetMidBlock2D(nn.Module):
|
|||||||
in_channels: int
|
in_channels: int
|
||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
num_layers: int = 1
|
num_layers: int = 1
|
||||||
|
resnet_groups: int = 32
|
||||||
attn_num_head_channels: int = 1
|
attn_num_head_channels: int = 1
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
|
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
|
||||||
|
|
||||||
# there is always at least one resnet
|
# there is always at least one resnet
|
||||||
resnets = [
|
resnets = [
|
||||||
FlaxResnetBlock2D(
|
FlaxResnetBlock2D(
|
||||||
in_channels=self.in_channels,
|
in_channels=self.in_channels,
|
||||||
out_channels=self.in_channels,
|
out_channels=self.in_channels,
|
||||||
dropout=self.dropout,
|
dropout=self.dropout,
|
||||||
|
groups=resnet_groups,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -392,7 +412,10 @@ class FlaxUNetMidBlock2D(nn.Module):
|
|||||||
|
|
||||||
for _ in range(self.num_layers):
|
for _ in range(self.num_layers):
|
||||||
attn_block = FlaxAttentionBlock(
|
attn_block = FlaxAttentionBlock(
|
||||||
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
|
channels=self.in_channels,
|
||||||
|
num_head_channels=self.attn_num_head_channels,
|
||||||
|
num_groups=resnet_groups,
|
||||||
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
attentions.append(attn_block)
|
attentions.append(attn_block)
|
||||||
|
|
||||||
@@ -400,6 +423,7 @@ class FlaxUNetMidBlock2D(nn.Module):
|
|||||||
in_channels=self.in_channels,
|
in_channels=self.in_channels,
|
||||||
out_channels=self.in_channels,
|
out_channels=self.in_channels,
|
||||||
dropout=self.dropout,
|
dropout=self.dropout,
|
||||||
|
groups=resnet_groups,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
resnets.append(res_block)
|
resnets.append(res_block)
|
||||||
@@ -441,7 +465,7 @@ class FlaxEncoder(nn.Module):
|
|||||||
Tuple containing the number of output channels for each block
|
Tuple containing the number of output channels for each block
|
||||||
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
|
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
|
||||||
Number of Resnet layer for each block
|
Number of Resnet layer for each block
|
||||||
norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
|
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||||
norm num group
|
norm num group
|
||||||
act_fn (:obj:`str`, *optional*, defaults to `silu`):
|
act_fn (:obj:`str`, *optional*, defaults to `silu`):
|
||||||
Activation function
|
Activation function
|
||||||
@@ -483,6 +507,7 @@ class FlaxEncoder(nn.Module):
|
|||||||
in_channels=input_channel,
|
in_channels=input_channel,
|
||||||
out_channels=output_channel,
|
out_channels=output_channel,
|
||||||
num_layers=self.layers_per_block,
|
num_layers=self.layers_per_block,
|
||||||
|
resnet_groups=self.norm_num_groups,
|
||||||
add_downsample=not is_final_block,
|
add_downsample=not is_final_block,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
@@ -491,12 +516,15 @@ class FlaxEncoder(nn.Module):
|
|||||||
|
|
||||||
# middle
|
# middle
|
||||||
self.mid_block = FlaxUNetMidBlock2D(
|
self.mid_block = FlaxUNetMidBlock2D(
|
||||||
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
|
in_channels=block_out_channels[-1],
|
||||||
|
resnet_groups=self.norm_num_groups,
|
||||||
|
attn_num_head_channels=None,
|
||||||
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# end
|
# end
|
||||||
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
|
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
|
||||||
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
|
||||||
self.conv_out = nn.Conv(
|
self.conv_out = nn.Conv(
|
||||||
conv_out_channels,
|
conv_out_channels,
|
||||||
kernel_size=(3, 3),
|
kernel_size=(3, 3),
|
||||||
@@ -581,7 +609,10 @@ class FlaxDecoder(nn.Module):
|
|||||||
|
|
||||||
# middle
|
# middle
|
||||||
self.mid_block = FlaxUNetMidBlock2D(
|
self.mid_block = FlaxUNetMidBlock2D(
|
||||||
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
|
in_channels=block_out_channels[-1],
|
||||||
|
resnet_groups=self.norm_num_groups,
|
||||||
|
attn_num_head_channels=None,
|
||||||
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# upsampling
|
# upsampling
|
||||||
@@ -594,10 +625,11 @@ class FlaxDecoder(nn.Module):
|
|||||||
|
|
||||||
is_final_block = i == len(block_out_channels) - 1
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
up_block = FlaxUpEncoderBlock2D(
|
up_block = FlaxUpDecoderBlock2D(
|
||||||
in_channels=prev_output_channel,
|
in_channels=prev_output_channel,
|
||||||
out_channels=output_channel,
|
out_channels=output_channel,
|
||||||
num_layers=self.layers_per_block + 1,
|
num_layers=self.layers_per_block + 1,
|
||||||
|
resnet_groups=self.norm_num_groups,
|
||||||
add_upsample=not is_final_block,
|
add_upsample=not is_final_block,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
@@ -607,7 +639,7 @@ class FlaxDecoder(nn.Module):
|
|||||||
self.up_blocks = up_blocks
|
self.up_blocks = up_blocks
|
||||||
|
|
||||||
# end
|
# end
|
||||||
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
|
||||||
self.conv_out = nn.Conv(
|
self.conv_out = nn.Conv(
|
||||||
self.out_channels,
|
self.out_channels,
|
||||||
kernel_size=(3, 3),
|
kernel_size=(3, 3),
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import PIL.ImageOps
|
|||||||
import requests
|
import requests
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
|
from .import_utils import is_flax_available
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
@@ -89,6 +91,13 @@ def slow(test_case):
|
|||||||
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_flax(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
|
||||||
|
"""
|
||||||
|
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
from diffusers.utils import is_flax_available
|
||||||
|
from diffusers.utils.testing_utils import require_flax
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jax
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class FlaxModelTesterMixin:
|
||||||
|
def test_output(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
|
||||||
|
jax.lax.stop_gradient(variables)
|
||||||
|
|
||||||
|
output = model.apply(variables, inputs_dict["sample"])
|
||||||
|
|
||||||
|
if isinstance(output, dict):
|
||||||
|
output = output.sample
|
||||||
|
|
||||||
|
self.assertIsNotNone(output)
|
||||||
|
expected_shape = inputs_dict["sample"].shape
|
||||||
|
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||||
|
|
||||||
|
def test_forward_with_norm_groups(self):
|
||||||
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|
||||||
|
init_dict["norm_num_groups"] = 16
|
||||||
|
init_dict["block_out_channels"] = (16, 32)
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
|
||||||
|
jax.lax.stop_gradient(variables)
|
||||||
|
|
||||||
|
output = model.apply(variables, inputs_dict["sample"])
|
||||||
|
|
||||||
|
if isinstance(output, dict):
|
||||||
|
output = output.sample
|
||||||
|
|
||||||
|
self.assertIsNotNone(output)
|
||||||
|
expected_shape = inputs_dict["sample"].shape
|
||||||
|
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from diffusers import FlaxAutoencoderKL
|
||||||
|
from diffusers.utils import is_flax_available
|
||||||
|
from diffusers.utils.testing_utils import require_flax
|
||||||
|
|
||||||
|
from .test_modeling_common_flax import FlaxModelTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jax
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
|
||||||
|
model_class = FlaxAutoencoderKL
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_input(self):
|
||||||
|
batch_size = 4
|
||||||
|
num_channels = 3
|
||||||
|
sizes = (32, 32)
|
||||||
|
|
||||||
|
prng_key = jax.random.PRNGKey(0)
|
||||||
|
image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
|
||||||
|
|
||||||
|
return {"sample": image, "prng_key": prng_key}
|
||||||
|
|
||||||
|
def prepare_init_args_and_inputs_for_common(self):
|
||||||
|
init_dict = {
|
||||||
|
"block_out_channels": [32, 64],
|
||||||
|
"in_channels": 3,
|
||||||
|
"out_channels": 3,
|
||||||
|
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||||
|
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||||
|
"latent_channels": 4,
|
||||||
|
}
|
||||||
|
inputs_dict = self.dummy_input
|
||||||
|
return init_dict, inputs_dict
|
||||||
Reference in New Issue
Block a user