Fix import with Flax but without PyTorch (#688)
* Don't use `load_state_dict` if torch is not installed.
* Define `SchedulerOutput` to use torch or flax arrays.
* Don't import LMSDiscreteScheduler without torch.
* Create distinct FlaxSchedulerOutput.
* Additional changes required for FlaxSchedulerMixin
* Do not import torch pipelines in Flax.
* Revert "Define `SchedulerOutput` to use torch or flax arrays."
This reverts commit f653140134.
* Prefix Flax scheduler outputs for consistency.
* make style
* FlaxSchedulerOutput is now a dataclass.
* Don't use f-string without placeholders.
* Add blank line.
* Style (docstrings)
This commit is contained in:
@@ -73,6 +73,7 @@ if is_flax_available():
|
||||
FlaxKarrasVeScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
FlaxSchedulerMixin,
|
||||
FlaxScoreSdeVeScheduler,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -27,8 +27,8 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests import HTTPError
|
||||
|
||||
from . import is_torch_available
|
||||
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
||||
from .modeling_utils import load_state_dict
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
@@ -391,6 +391,14 @@ class FlaxModelMixin:
|
||||
)
|
||||
|
||||
if from_pt:
|
||||
if is_torch_available():
|
||||
from .modeling_utils import load_state_dict
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
"Can't load the model in PyTorch format because PyTorch is not installed. "
|
||||
"Please, install PyTorch or use native Flax weights."
|
||||
)
|
||||
|
||||
# Step 1: Get the pytorch file
|
||||
pytorch_model_file = load_state_dict(model_file)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
|
||||
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
|
||||
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
|
||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ logger = logging.get_logger(__name__)
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"SchedulerMixin": ["save_config", "from_config"],
|
||||
"FlaxSchedulerMixin": ["save_config", "from_config"],
|
||||
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"transformers": {
|
||||
@@ -436,7 +436,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
else:
|
||||
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
|
||||
params[name] = loaded_params
|
||||
elif issubclass(class_obj, SchedulerMixin):
|
||||
elif issubclass(class_obj, FlaxSchedulerMixin):
|
||||
loaded_sub_model, scheduler_state = load_method(loadable_folder)
|
||||
params[name] = scheduler_state
|
||||
else:
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .ddim import DDIMPipeline
|
||||
from .ddpm import DDPMPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pndm import PNDMPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .stochastic_karras_ve import KarrasVePipeline
|
||||
|
||||
else:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
|
||||
if is_torch_available() and is_transformers_available():
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
|
||||
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -27,7 +27,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
|
||||
nsfw_content_detected: List[bool]
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
if is_transformers_available() and is_torch_available():
|
||||
from .pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
|
||||
@@ -34,10 +34,12 @@ if is_flax_available():
|
||||
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
|
||||
from .scheduling_pndm_flax import FlaxPNDMScheduler
|
||||
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin
|
||||
else:
|
||||
from ..utils.dummy_flax_objects import * # noqa F403
|
||||
|
||||
if is_scipy_available():
|
||||
|
||||
if is_scipy_available() and is_torch_available():
|
||||
from .scheduling_lms_discrete import LMSDiscreteScheduler
|
||||
else:
|
||||
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
|
||||
|
||||
@@ -23,7 +23,7 @@ import flax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -68,11 +68,11 @@ class DDIMSchedulerState:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxSchedulerOutput(SchedulerOutput):
|
||||
class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
|
||||
state: DDIMSchedulerState
|
||||
|
||||
|
||||
class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
|
||||
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
|
||||
@@ -183,7 +183,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
||||
) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
@@ -197,11 +197,11 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
key (`random.KeyArray`): a PRNG key.
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): TODO
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
[`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if state.num_inference_steps is None:
|
||||
@@ -252,7 +252,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
|
||||
@@ -23,7 +23,7 @@ import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -67,11 +67,11 @@ class DDPMSchedulerState:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxSchedulerOutput(SchedulerOutput):
|
||||
class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
|
||||
state: DDPMSchedulerState
|
||||
|
||||
|
||||
class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
|
||||
Langevin dynamics sampling.
|
||||
@@ -191,7 +191,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
key: random.KeyArray,
|
||||
predict_epsilon: bool = True,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
||||
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
@@ -205,11 +205,11 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
key (`random.KeyArray`): a PRNG key.
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
[`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
t = timestep
|
||||
@@ -257,7 +257,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (pred_prev_sample, state)
|
||||
|
||||
return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
|
||||
return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
|
||||
@@ -22,7 +22,7 @@ from jax import random
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
@@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput):
|
||||
state: KarrasVeSchedulerState
|
||||
|
||||
|
||||
class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
|
||||
the VE column of Table 1 from [1] for reference.
|
||||
@@ -172,7 +172,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigma_hat (`float`): TODO
|
||||
sigma_prev (`float`): TODO
|
||||
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
|
||||
@@ -211,7 +211,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
|
||||
|
||||
Returns:
|
||||
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
|
||||
|
||||
@@ -20,7 +20,7 @@ import jax.numpy as jnp
|
||||
from scipy import integrate
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
@@ -37,11 +37,11 @@ class LMSDiscreteSchedulerState:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxSchedulerOutput(SchedulerOutput):
|
||||
class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
|
||||
state: LMSDiscreteSchedulerState
|
||||
|
||||
|
||||
class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
|
||||
Katherine Crowson:
|
||||
@@ -147,7 +147,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample: jnp.ndarray,
|
||||
order: int = 4,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
) -> Union[FlaxLMSSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
@@ -159,11 +159,11 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
order: coefficient for multi-step inference.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
[`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
sigma = state.sigmas[timestep]
|
||||
@@ -189,7 +189,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
|
||||
@@ -23,7 +23,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -76,11 +76,11 @@ class PNDMSchedulerState:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxSchedulerOutput(SchedulerOutput):
|
||||
class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
|
||||
state: PNDMSchedulerState
|
||||
|
||||
|
||||
class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
|
||||
namely Runge-Kutta method and a linear multi-step method.
|
||||
@@ -211,7 +211,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
||||
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
@@ -224,11 +224,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if self.config.skip_prk_steps:
|
||||
@@ -249,7 +249,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
|
||||
def step_prk(
|
||||
self,
|
||||
@@ -257,7 +257,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
||||
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
|
||||
solution to the differential equation.
|
||||
@@ -268,11 +268,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if state.num_inference_steps is None:
|
||||
@@ -327,7 +327,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
||||
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
|
||||
times to approximate the solution.
|
||||
@@ -338,11 +338,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if state.num_inference_steps is None:
|
||||
|
||||
@@ -22,7 +22,7 @@ import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
@@ -38,7 +38,7 @@ class ScoreSdeVeSchedulerState:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxSdeVeOutput(SchedulerOutput):
|
||||
class FlaxSdeVeOutput(FlaxSchedulerOutput):
|
||||
"""
|
||||
Output class for the ScoreSdeVeScheduler's step function output.
|
||||
|
||||
@@ -56,7 +56,7 @@ class FlaxSdeVeOutput(SchedulerOutput):
|
||||
prev_sample_mean: Optional[jnp.ndarray] = None
|
||||
|
||||
|
||||
class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
The variance exploding stochastic differential equation (SDE) scheduler.
|
||||
|
||||
@@ -168,7 +168,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
generator: random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
|
||||
|
||||
Returns:
|
||||
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
@@ -216,7 +216,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample: jnp.ndarray,
|
||||
key: random.KeyArray,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
) -> Union[FlaxSdeVeOutput, Tuple]:
|
||||
"""
|
||||
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
|
||||
after making the prediction for the previous timestep.
|
||||
@@ -227,7 +227,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
generator: random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
|
||||
|
||||
Returns:
|
||||
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..utils import BaseOutput
|
||||
|
||||
|
||||
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Base class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: jnp.ndarray
|
||||
|
||||
|
||||
class FlaxSchedulerMixin:
|
||||
"""
|
||||
Mixin containing common functions for the schedulers.
|
||||
"""
|
||||
|
||||
config_name = SCHEDULER_CONFIG_NAME
|
||||
|
||||
def set_format(self, tensor_format="pt"):
|
||||
warnings.warn(
|
||||
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
|
||||
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
|
||||
"are always in Pytorch",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return self
|
||||
Reference in New Issue
Block a user