Fix import with Flax but without PyTorch (#688)
* Don't use `load_state_dict` if torch is not installed.
* Define `SchedulerOutput` to use torch or flax arrays.
* Don't import LMSDiscreteScheduler without torch.
* Create distinct FlaxSchedulerOutput.
* Additional changes required for FlaxSchedulerMixin
* Do not import torch pipelines in Flax.
* Revert "Define `SchedulerOutput` to use torch or flax arrays."
This reverts commit f653140134.
* Prefix Flax scheduler outputs for consistency.
* make style
* FlaxSchedulerOutput is now a dataclass.
* Don't use f-string without placeholders.
* Add blank line.
* Style (docstrings)
This commit is contained in:
@@ -73,6 +73,7 @@ if is_flax_available():
|
|||||||
FlaxKarrasVeScheduler,
|
FlaxKarrasVeScheduler,
|
||||||
FlaxLMSDiscreteScheduler,
|
FlaxLMSDiscreteScheduler,
|
||||||
FlaxPNDMScheduler,
|
FlaxPNDMScheduler,
|
||||||
|
FlaxSchedulerMixin,
|
||||||
FlaxScoreSdeVeScheduler,
|
FlaxScoreSdeVeScheduler,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ from huggingface_hub import hf_hub_download
|
|||||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
|
|
||||||
|
from . import is_torch_available
|
||||||
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
||||||
from .modeling_utils import load_state_dict
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
DIFFUSERS_CACHE,
|
DIFFUSERS_CACHE,
|
||||||
@@ -391,6 +391,14 @@ class FlaxModelMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if from_pt:
|
if from_pt:
|
||||||
|
if is_torch_available():
|
||||||
|
from .modeling_utils import load_state_dict
|
||||||
|
else:
|
||||||
|
raise EnvironmentError(
|
||||||
|
"Can't load the model in PyTorch format because PyTorch is not installed. "
|
||||||
|
"Please, install PyTorch or use native Flax weights."
|
||||||
|
)
|
||||||
|
|
||||||
# Step 1: Get the pytorch file
|
# Step 1: Get the pytorch file
|
||||||
pytorch_model_file = load_state_dict(model_file)
|
pytorch_model_file = load_state_dict(model_file)
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from tqdm.auto import tqdm
|
|||||||
|
|
||||||
from .configuration_utils import ConfigMixin
|
from .configuration_utils import ConfigMixin
|
||||||
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
|
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
|
||||||
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
|
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
|
||||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
|
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
|
||||||
|
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ logger = logging.get_logger(__name__)
|
|||||||
LOADABLE_CLASSES = {
|
LOADABLE_CLASSES = {
|
||||||
"diffusers": {
|
"diffusers": {
|
||||||
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
|
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
|
||||||
"SchedulerMixin": ["save_config", "from_config"],
|
"FlaxSchedulerMixin": ["save_config", "from_config"],
|
||||||
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||||
},
|
},
|
||||||
"transformers": {
|
"transformers": {
|
||||||
@@ -436,7 +436,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||||||
else:
|
else:
|
||||||
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
|
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
|
||||||
params[name] = loaded_params
|
params[name] = loaded_params
|
||||||
elif issubclass(class_obj, SchedulerMixin):
|
elif issubclass(class_obj, FlaxSchedulerMixin):
|
||||||
loaded_sub_model, scheduler_state = load_method(loadable_folder)
|
loaded_sub_model, scheduler_state = load_method(loadable_folder)
|
||||||
params[name] = scheduler_state
|
params[name] = scheduler_state
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
|
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
|
||||||
from .ddim import DDIMPipeline
|
|
||||||
from .ddpm import DDPMPipeline
|
|
||||||
from .latent_diffusion_uncond import LDMPipeline
|
|
||||||
from .pndm import PNDMPipeline
|
|
||||||
from .score_sde_ve import ScoreSdeVePipeline
|
|
||||||
from .stochastic_karras_ve import KarrasVePipeline
|
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from .ddim import DDIMPipeline
|
||||||
|
from .ddpm import DDPMPipeline
|
||||||
|
from .latent_diffusion_uncond import LDMPipeline
|
||||||
|
from .pndm import PNDMPipeline
|
||||||
|
from .score_sde_ve import ScoreSdeVePipeline
|
||||||
|
from .stochastic_karras_ve import KarrasVePipeline
|
||||||
|
else:
|
||||||
|
from ..utils.dummy_pt_objects import * # noqa F403
|
||||||
|
|
||||||
if is_torch_available() and is_transformers_available():
|
if is_torch_available() and is_transformers_available():
|
||||||
from .latent_diffusion import LDMTextToImagePipeline
|
from .latent_diffusion import LDMTextToImagePipeline
|
||||||
from .stable_diffusion import (
|
from .stable_diffusion import (
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import PIL
|
import PIL
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
|
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -27,7 +27,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
|
|||||||
nsfw_content_detected: List[bool]
|
nsfw_content_detected: List[bool]
|
||||||
|
|
||||||
|
|
||||||
if is_transformers_available():
|
if is_transformers_available() and is_torch_available():
|
||||||
from .pipeline_stable_diffusion import StableDiffusionPipeline
|
from .pipeline_stable_diffusion import StableDiffusionPipeline
|
||||||
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||||
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||||
|
|||||||
@@ -34,10 +34,12 @@ if is_flax_available():
|
|||||||
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
|
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
|
||||||
from .scheduling_pndm_flax import FlaxPNDMScheduler
|
from .scheduling_pndm_flax import FlaxPNDMScheduler
|
||||||
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
|
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
|
||||||
|
from .scheduling_utils_flax import FlaxSchedulerMixin
|
||||||
else:
|
else:
|
||||||
from ..utils.dummy_flax_objects import * # noqa F403
|
from ..utils.dummy_flax_objects import * # noqa F403
|
||||||
|
|
||||||
if is_scipy_available():
|
|
||||||
|
if is_scipy_available() and is_torch_available():
|
||||||
from .scheduling_lms_discrete import LMSDiscreteScheduler
|
from .scheduling_lms_discrete import LMSDiscreteScheduler
|
||||||
else:
|
else:
|
||||||
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
|
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import flax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
||||||
@@ -68,11 +68,11 @@ class DDIMSchedulerState:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlaxSchedulerOutput(SchedulerOutput):
|
class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
|
||||||
state: DDIMSchedulerState
|
state: DDIMSchedulerState
|
||||||
|
|
||||||
|
|
||||||
class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
|
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
|
||||||
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
|
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
|
||||||
@@ -183,7 +183,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
timestep: int,
|
timestep: int,
|
||||||
sample: jnp.ndarray,
|
sample: jnp.ndarray,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
|
||||||
"""
|
"""
|
||||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||||
process from the learned model outputs (most often the predicted noise).
|
process from the learned model outputs (most often the predicted noise).
|
||||||
@@ -197,11 +197,11 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
key (`random.KeyArray`): a PRNG key.
|
key (`random.KeyArray`): a PRNG key.
|
||||||
eta (`float`): weight of noise for added noise in diffusion step.
|
eta (`float`): weight of noise for added noise in diffusion step.
|
||||||
use_clipped_model_output (`bool`): TODO
|
use_clipped_model_output (`bool`): TODO
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
[`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||||
When returning a tuple, the first element is the sample tensor.
|
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if state.num_inference_steps is None:
|
if state.num_inference_steps is None:
|
||||||
@@ -252,7 +252,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (prev_sample, state)
|
return (prev_sample, state)
|
||||||
|
|
||||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||||
|
|
||||||
def add_noise(
|
def add_noise(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import jax.numpy as jnp
|
|||||||
from jax import random
|
from jax import random
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
||||||
@@ -67,11 +67,11 @@ class DDPMSchedulerState:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlaxSchedulerOutput(SchedulerOutput):
|
class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
|
||||||
state: DDPMSchedulerState
|
state: DDPMSchedulerState
|
||||||
|
|
||||||
|
|
||||||
class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
|
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
|
||||||
Langevin dynamics sampling.
|
Langevin dynamics sampling.
|
||||||
@@ -191,7 +191,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
key: random.KeyArray,
|
key: random.KeyArray,
|
||||||
predict_epsilon: bool = True,
|
predict_epsilon: bool = True,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
|
||||||
"""
|
"""
|
||||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||||
process from the learned model outputs (most often the predicted noise).
|
process from the learned model outputs (most often the predicted noise).
|
||||||
@@ -205,11 +205,11 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
key (`random.KeyArray`): a PRNG key.
|
key (`random.KeyArray`): a PRNG key.
|
||||||
predict_epsilon (`bool`):
|
predict_epsilon (`bool`):
|
||||||
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
|
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
[`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||||
When returning a tuple, the first element is the sample tensor.
|
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
t = timestep
|
t = timestep
|
||||||
@@ -257,7 +257,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (pred_prev_sample, state)
|
return (pred_prev_sample, state)
|
||||||
|
|
||||||
return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
|
return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)
|
||||||
|
|
||||||
def add_noise(
|
def add_noise(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from jax import random
|
|||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import BaseOutput
|
from ..utils import BaseOutput
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils_flax import FlaxSchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput):
|
|||||||
state: KarrasVeSchedulerState
|
state: KarrasVeSchedulerState
|
||||||
|
|
||||||
|
|
||||||
class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
|
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
|
||||||
the VE column of Table 1 from [1] for reference.
|
the VE column of Table 1 from [1] for reference.
|
||||||
@@ -172,7 +172,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sigma_hat (`float`): TODO
|
sigma_hat (`float`): TODO
|
||||||
sigma_prev (`float`): TODO
|
sigma_prev (`float`): TODO
|
||||||
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
|
[`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
|
||||||
@@ -211,7 +211,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||||
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
|
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||||
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
|
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
|
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import jax.numpy as jnp
|
|||||||
from scipy import integrate
|
from scipy import integrate
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -37,11 +37,11 @@ class LMSDiscreteSchedulerState:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlaxSchedulerOutput(SchedulerOutput):
|
class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
|
||||||
state: LMSDiscreteSchedulerState
|
state: LMSDiscreteSchedulerState
|
||||||
|
|
||||||
|
|
||||||
class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
|
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
|
||||||
Katherine Crowson:
|
Katherine Crowson:
|
||||||
@@ -147,7 +147,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sample: jnp.ndarray,
|
sample: jnp.ndarray,
|
||||||
order: int = 4,
|
order: int = 4,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[SchedulerOutput, Tuple]:
|
) -> Union[FlaxLMSSchedulerOutput, Tuple]:
|
||||||
"""
|
"""
|
||||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||||
process from the learned model outputs (most often the predicted noise).
|
process from the learned model outputs (most often the predicted noise).
|
||||||
@@ -159,11 +159,11 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sample (`jnp.ndarray`):
|
sample (`jnp.ndarray`):
|
||||||
current instance of sample being created by diffusion process.
|
current instance of sample being created by diffusion process.
|
||||||
order: coefficient for multi-step inference.
|
order: coefficient for multi-step inference.
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
[`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||||
When returning a tuple, the first element is the sample tensor.
|
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
sigma = state.sigmas[timestep]
|
sigma = state.sigmas[timestep]
|
||||||
@@ -189,7 +189,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (prev_sample, state)
|
return (prev_sample, state)
|
||||||
|
|
||||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||||
|
|
||||||
def add_noise(
|
def add_noise(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
|
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
|
||||||
@@ -76,11 +76,11 @@ class PNDMSchedulerState:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlaxSchedulerOutput(SchedulerOutput):
|
class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
|
||||||
state: PNDMSchedulerState
|
state: PNDMSchedulerState
|
||||||
|
|
||||||
|
|
||||||
class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
|
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
|
||||||
namely Runge-Kutta method and a linear multi-step method.
|
namely Runge-Kutta method and a linear multi-step method.
|
||||||
@@ -211,7 +211,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
timestep: int,
|
timestep: int,
|
||||||
sample: jnp.ndarray,
|
sample: jnp.ndarray,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
|
||||||
"""
|
"""
|
||||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||||
process from the learned model outputs (most often the predicted noise).
|
process from the learned model outputs (most often the predicted noise).
|
||||||
@@ -224,11 +224,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||||
sample (`jnp.ndarray`):
|
sample (`jnp.ndarray`):
|
||||||
current instance of sample being created by diffusion process.
|
current instance of sample being created by diffusion process.
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||||
When returning a tuple, the first element is the sample tensor.
|
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self.config.skip_prk_steps:
|
if self.config.skip_prk_steps:
|
||||||
@@ -249,7 +249,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (prev_sample, state)
|
return (prev_sample, state)
|
||||||
|
|
||||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||||
|
|
||||||
def step_prk(
|
def step_prk(
|
||||||
self,
|
self,
|
||||||
@@ -257,7 +257,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
model_output: jnp.ndarray,
|
model_output: jnp.ndarray,
|
||||||
timestep: int,
|
timestep: int,
|
||||||
sample: jnp.ndarray,
|
sample: jnp.ndarray,
|
||||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
|
||||||
"""
|
"""
|
||||||
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
|
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
|
||||||
solution to the differential equation.
|
solution to the differential equation.
|
||||||
@@ -268,11 +268,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||||
sample (`jnp.ndarray`):
|
sample (`jnp.ndarray`):
|
||||||
current instance of sample being created by diffusion process.
|
current instance of sample being created by diffusion process.
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||||
When returning a tuple, the first element is the sample tensor.
|
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if state.num_inference_steps is None:
|
if state.num_inference_steps is None:
|
||||||
@@ -327,7 +327,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
model_output: jnp.ndarray,
|
model_output: jnp.ndarray,
|
||||||
timestep: int,
|
timestep: int,
|
||||||
sample: jnp.ndarray,
|
sample: jnp.ndarray,
|
||||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
|
||||||
"""
|
"""
|
||||||
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
|
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
|
||||||
times to approximate the solution.
|
times to approximate the solution.
|
||||||
@@ -338,11 +338,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||||
sample (`jnp.ndarray`):
|
sample (`jnp.ndarray`):
|
||||||
current instance of sample being created by diffusion process.
|
current instance of sample being created by diffusion process.
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
[`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||||
When returning a tuple, the first element is the sample tensor.
|
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if state.num_inference_steps is None:
|
if state.num_inference_steps is None:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import jax.numpy as jnp
|
|||||||
from jax import random
|
from jax import random
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -38,7 +38,7 @@ class ScoreSdeVeSchedulerState:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlaxSdeVeOutput(SchedulerOutput):
|
class FlaxSdeVeOutput(FlaxSchedulerOutput):
|
||||||
"""
|
"""
|
||||||
Output class for the ScoreSdeVeScheduler's step function output.
|
Output class for the ScoreSdeVeScheduler's step function output.
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ class FlaxSdeVeOutput(SchedulerOutput):
|
|||||||
prev_sample_mean: Optional[jnp.ndarray] = None
|
prev_sample_mean: Optional[jnp.ndarray] = None
|
||||||
|
|
||||||
|
|
||||||
class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
The variance exploding stochastic differential equation (SDE) scheduler.
|
The variance exploding stochastic differential equation (SDE) scheduler.
|
||||||
|
|
||||||
@@ -168,7 +168,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sample (`jnp.ndarray`):
|
sample (`jnp.ndarray`):
|
||||||
current instance of sample being created by diffusion process.
|
current instance of sample being created by diffusion process.
|
||||||
generator: random number generator.
|
generator: random number generator.
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||||
@@ -216,7 +216,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sample: jnp.ndarray,
|
sample: jnp.ndarray,
|
||||||
key: random.KeyArray,
|
key: random.KeyArray,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[SchedulerOutput, Tuple]:
|
) -> Union[FlaxSdeVeOutput, Tuple]:
|
||||||
"""
|
"""
|
||||||
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
|
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
|
||||||
after making the prediction for the previous timestep.
|
after making the prediction for the previous timestep.
|
||||||
@@ -227,7 +227,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sample (`jnp.ndarray`):
|
sample (`jnp.ndarray`):
|
||||||
current instance of sample being created by diffusion process.
|
current instance of sample being created by diffusion process.
|
||||||
generator: random number generator.
|
generator: random number generator.
|
||||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from ..utils import BaseOutput
|
||||||
|
|
||||||
|
|
||||||
|
# Config file name; assigned to `FlaxSchedulerMixin.config_name` in this module.
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FlaxSchedulerOutput(BaseOutput):
    """
    Common output container for the step function of a scheduler.

    Args:
        prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            Sample computed for the previous timestep (x_{t-1}). Feed `prev_sample` back to the model as its next
            input inside the denoising loop.
    """

    prev_sample: jnp.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxSchedulerMixin:
    """
    Mixin containing common functions for the schedulers.
    """

    # Name of the scheduler config file; presumably consumed by the config
    # loading/saving machinery of the classes this mixin is combined with -- confirm.
    config_name = SCHEDULER_CONFIG_NAME

    def set_format(self, tensor_format="pt"):
        """
        Deprecated no-op: emit a `DeprecationWarning` and return the scheduler unchanged.

        Args:
            tensor_format (`str`, defaults to `"pt"`): accepted for backward compatibility and ignored.

        Returns:
            The scheduler instance itself, so existing chained calls keep working.
        """
        # Adjacent string literals are concatenated verbatim, so each fragment must end with
        # an explicit space; otherwise the message reads "...`0.5.0`.If you're..." and
        # "...schedulersare always...".
        warnings.warn(
            "The method `set_format` is deprecated and will be removed in version `0.5.0`. "
            "If you're running your code in PyTorch, you can safely remove this function as the schedulers "
            "are always in Pytorch",
            DeprecationWarning,
            # Attribute the warning to the code that called `set_format`, not to this helper.
            stacklevel=2,
        )
        return self
|
||||||
Reference in New Issue
Block a user