Support training SD V2 with Flax (#1783)
* Support training SD V2 with Flax. This mostly involves supporting a v_prediction scheduler. The implementation in #1777 doesn't take into account a recent refactor of `scheduling_utils_flax`, so this one should be used instead.
* Add the same change to the other top-level files.
This commit is contained in:
@@ -525,28 +525,35 @@ def main():
|
|||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
unet_outputs = unet.apply(
|
model_pred = unet.apply(
|
||||||
{"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
|
{"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
|
||||||
)
|
).sample
|
||||||
noise_pred = unet_outputs.sample
|
|
||||||
|
# Get the target for loss depending on the prediction type
|
||||||
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
|
target = noise
|
||||||
|
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||||
|
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||||
|
|
||||||
if args.with_prior_preservation:
|
if args.with_prior_preservation:
|
||||||
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
|
# Chunk the model_pred and target into two parts and compute the loss on each part separately.
|
||||||
noise_pred, noise_pred_prior = jnp.split(noise_pred, 2, axis=0)
|
model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
|
||||||
noise, noise_prior = jnp.split(noise, 2, axis=0)
|
target, target_prior = jnp.split(target, 2, axis=0)
|
||||||
|
|
||||||
# Compute instance loss
|
# Compute instance loss
|
||||||
loss = (noise - noise_pred) ** 2
|
loss = (target - model_pred) ** 2
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
|
||||||
# Compute prior loss
|
# Compute prior loss
|
||||||
prior_loss = (noise_prior - noise_pred_prior) ** 2
|
prior_loss = (target_prior - model_pred_prior) ** 2
|
||||||
prior_loss = prior_loss.mean()
|
prior_loss = prior_loss.mean()
|
||||||
|
|
||||||
# Add the prior loss to the instance loss.
|
# Add the prior loss to the instance loss.
|
||||||
loss = loss + args.prior_loss_weight * prior_loss
|
loss = loss + args.prior_loss_weight * prior_loss
|
||||||
else:
|
else:
|
||||||
loss = (noise - noise_pred) ** 2
|
loss = (target - model_pred) ** 2
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -459,9 +459,19 @@ def main():
|
|||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# Predict the noise residual and compute loss
|
# Predict the noise residual and compute loss
|
||||||
unet_outputs = unet.apply({"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True)
|
model_pred = unet.apply(
|
||||||
noise_pred = unet_outputs.sample
|
{"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True
|
||||||
loss = (noise - noise_pred) ** 2
|
).sample
|
||||||
|
|
||||||
|
# Get the target for loss depending on the prediction type
|
||||||
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
|
target = noise
|
||||||
|
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||||
|
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||||
|
|
||||||
|
loss = (target - model_pred) ** 2
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -536,11 +536,20 @@ def main():
|
|||||||
encoder_hidden_states = state.apply_fn(
|
encoder_hidden_states = state.apply_fn(
|
||||||
batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
|
batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
|
||||||
)[0]
|
)[0]
|
||||||
unet_outputs = unet.apply(
|
# Predict the noise residual and compute loss
|
||||||
|
model_pred = unet.apply(
|
||||||
{"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
|
{"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
|
||||||
)
|
).sample
|
||||||
noise_pred = unet_outputs.sample
|
|
||||||
loss = (noise - noise_pred) ** 2
|
# Get the target for loss depending on the prediction type
|
||||||
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
|
target = noise
|
||||||
|
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||||
|
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||||
|
|
||||||
|
loss = (target - model_pred) ** 2
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from .scheduling_utils_flax import (
|
|||||||
FlaxSchedulerMixin,
|
FlaxSchedulerMixin,
|
||||||
FlaxSchedulerOutput,
|
FlaxSchedulerOutput,
|
||||||
add_noise_common,
|
add_noise_common,
|
||||||
|
get_velocity_common,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -301,5 +302,14 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
return add_noise_common(state.common, original_samples, noise, timesteps)
|
return add_noise_common(state.common, original_samples, noise, timesteps)
|
||||||
|
|
||||||
|
def get_velocity(
|
||||||
|
self,
|
||||||
|
state: DDIMSchedulerState,
|
||||||
|
sample: jnp.ndarray,
|
||||||
|
noise: jnp.ndarray,
|
||||||
|
timesteps: jnp.ndarray,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
return get_velocity_common(state.common, sample, noise, timesteps)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.config.num_train_timesteps
|
return self.config.num_train_timesteps
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from .scheduling_utils_flax import (
|
|||||||
FlaxSchedulerMixin,
|
FlaxSchedulerMixin,
|
||||||
FlaxSchedulerOutput,
|
FlaxSchedulerOutput,
|
||||||
add_noise_common,
|
add_noise_common,
|
||||||
|
get_velocity_common,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -293,5 +294,14 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
return add_noise_common(state.common, original_samples, noise, timesteps)
|
return add_noise_common(state.common, original_samples, noise, timesteps)
|
||||||
|
|
||||||
|
def get_velocity(
|
||||||
|
self,
|
||||||
|
state: DDPMSchedulerState,
|
||||||
|
sample: jnp.ndarray,
|
||||||
|
noise: jnp.ndarray,
|
||||||
|
timesteps: jnp.ndarray,
|
||||||
|
) -> jnp.ndarray:
|
||||||
|
return get_velocity_common(state.common, sample, noise, timesteps)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.config.num_train_timesteps
|
return self.config.num_train_timesteps
|
||||||
|
|||||||
@@ -242,7 +242,7 @@ class CommonSchedulerState:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_noise_common(
|
def get_sqrt_alpha_prod(
|
||||||
state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
|
state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
|
||||||
):
|
):
|
||||||
alphas_cumprod = state.alphas_cumprod
|
alphas_cumprod = state.alphas_cumprod
|
||||||
@@ -255,5 +255,18 @@ def add_noise_common(
|
|||||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||||
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
|
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
|
||||||
|
|
||||||
|
return sqrt_alpha_prod, sqrt_one_minus_alpha_prod
|
||||||
|
|
||||||
|
|
||||||
|
def add_noise_common(
|
||||||
|
state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
|
||||||
|
):
|
||||||
|
sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
|
||||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||||
return noisy_samples
|
return noisy_samples
|
||||||
|
|
||||||
|
|
||||||
|
def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray):
|
||||||
|
sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps)
|
||||||
|
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||||
|
return velocity
|
||||||
|
|||||||
Reference in New Issue
Block a user