update flax scheduler API (#822)

* update flax scheduler API
* remove set format
* fix call to scale_model_input
* update flax pndm
* use int32
* update docstring
@@ -170,6 +170,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
             timestep = jnp.broadcast_to(t, latents_input.shape[0])
 
+            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
+
             # predict the noise residual
             noise_pred = self.unet.apply(
                 {"params": params["unet"]},
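A side note on the two timestep values in this hunk: the new scale_model_input call takes the scalar t, while the UNet receives the broadcast timestep vector, one entry per sample in the (guidance-duplicated) batch. A small shape illustration (the values are made up):

import jax.numpy as jnp

t = jnp.int32(981)                    # scalar timestep for the scheduler
timestep = jnp.broadcast_to(t, (2,))  # per-sample vector for the UNet
print(t.shape, timestep.shape)        # () (2,)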
@@ -189,6 +191,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
         )
 
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+
         if debug:
             # run with python for loop
             for i in range(num_inference_steps):
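Taken together, the two pipeline hunks show the updated stateless calling convention: create a scheduler state, set the timesteps, scale the initial noise by init_noise_sigma, and pass each model input through scale_model_input before the UNet call. Below is a minimal sketch of that convention, assuming FlaxDDIMScheduler follows the signatures shown in these hunks; predict_noise, the latent shape, and the step count are illustrative stand-ins, not part of the diff.

import jax
import jax.numpy as jnp
from diffusers import FlaxDDIMScheduler

# Stateless scheduler: all mutable data lives in an explicit state object.
scheduler = FlaxDDIMScheduler()
state = scheduler.create_state()

shape = (1, 4, 64, 64)  # illustrative latent shape
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=shape)

latents = jax.random.normal(jax.random.PRNGKey(0), shape)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler.init_noise_sigma

def predict_noise(sample, t):
    return sample  # stand-in for the UNet forward pass

for i in range(50):
    t = state.timesteps[i]
    latents_input = scheduler.scale_model_input(state, latents, t)
    noise_pred = predict_noise(latents_input, t)
    out = scheduler.step(state, noise_pred, t, latents)
    latents, state = out.prev_sample, out.state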
@@ -141,6 +141,23 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         # whether we use the final alpha of the "non-previous" one.
         self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+    def scale_model_input(
+        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+    ) -> jnp.ndarray:
+        """
+        Args:
+            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
+            sample (`jnp.ndarray`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `jnp.ndarray`: scaled input sample
+        """
+        return sample
+
     def create_state(self):
         return DDIMSchedulerState.create(
             num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
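scale_model_input is an identity for DDIM, but the hook exists so pipelines can treat every scheduler uniformly: sigma-based schedulers genuinely rescale the input. A purely hypothetical example of such a non-trivial implementation, dividing by sqrt(sigma^2 + 1); the state.sigmas field is assumed here and is not part of this diff.

import jax.numpy as jnp

def scale_model_input(state, sample, timestep):
    # Hypothetical: `state.sigmas` as carried by sigma-based schedulers;
    # DDIM itself needs no scaling, hence the `return sample` above.
    step_index = jnp.argmax(state.timesteps == timestep)
    sigma = state.sigmas[step_index]
    return sample / jnp.sqrt(sigma**2 + 1)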
@@ -153,6 +153,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
         # mainly at formula (9), (12), (13) and the Algorithm 2.
         self.pndm_order = 4
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
     def create_state(self):
         return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
 
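With init_noise_sigma now on FlaxPNDMScheduler as well, a pipeline can scale its starting latents the same way regardless of scheduler; for PNDM and DDIM the factor is 1.0, so the multiplication is a no-op that exists purely for interchangeability. A minimal sketch (the shape and PRNG key are illustrative):

import jax
import jax.numpy as jnp
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler()
latents = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 64))
# Multiplies by 1.0 for PNDM; other schedulers may require a larger
# standard deviation for the initial noise.
latents = latents * scheduler.init_noise_sigma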
@@ -196,7 +199,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
         )
 
         return state.replace(
-            timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
+            timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
             counter=0,
             # Reserve space for the state variables
             cur_model_output=jnp.zeros(shape),
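The int64 to int32 change matters because JAX keeps 64-bit types disabled by default: unless jax_enable_x64 is set, a requested int64 array is silently downcast to int32 (with a warning), which can cause dtype mismatches inside jit-compiled loops. A quick demonstration of the default behavior:

import jax.numpy as jnp

# With the default config (jax_enable_x64 off), requesting int64
# yields an int32 array and emits a UserWarning.
x = jnp.arange(3, dtype=jnp.int64)
print(x.dtype)  # int32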
@@ -204,6 +207,23 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
             ets=jnp.zeros((4,) + shape),
         )
 
+    def scale_model_input(
+        self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+    ) -> jnp.ndarray:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+            sample (`jnp.ndarray`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `jnp.ndarray`: scaled input sample
+        """
+        return sample
+
     def step(
         self,
         state: PNDMSchedulerState,
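Because PNDM now exposes the same no-op scale_model_input, a pipeline step can call it unconditionally and stay scheduler-agnostic. A sketch of that pattern; denoise_step and predict_noise are hypothetical names, and the output fields assume the prev_sample/state pair the Flax scheduler outputs carry.

def denoise_step(scheduler, state, latents, t, predict_noise):
    # Identity for PNDM/DDIM, a real rescaling for sigma-based schedulers.
    latents_input = scheduler.scale_model_input(state, latents, t)
    noise_pred = predict_noise(latents_input, t)
    out = scheduler.step(state, noise_pred, t, latents)
    return out.prev_sample, out.state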