[v0.4.0] Temporarily remove Flax modules from the public API (#755)
Temporarily remove Flax modules from the public API
This commit is contained in:
@@ -45,21 +45,3 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
|
|||||||
|
|
||||||
## AutoencoderKL
|
## AutoencoderKL
|
||||||
[[autodoc]] AutoencoderKL
|
[[autodoc]] AutoencoderKL
|
||||||
|
|
||||||
## FlaxModelMixin
|
|
||||||
[[autodoc]] FlaxModelMixin
|
|
||||||
|
|
||||||
## FlaxUNet2DConditionOutput
|
|
||||||
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
|
|
||||||
|
|
||||||
## FlaxUNet2DConditionModel
|
|
||||||
[[autodoc]] FlaxUNet2DConditionModel
|
|
||||||
|
|
||||||
## FlaxDecoderOutput
|
|
||||||
[[autodoc]] models.vae_flax.FlaxDecoderOutput
|
|
||||||
|
|
||||||
## FlaxAutoencoderKLOutput
|
|
||||||
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
|
|
||||||
|
|
||||||
## FlaxAutoencoderKL
|
|
||||||
[[autodoc]] FlaxAutoencoderKL
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
|
|||||||
To this end, the design of schedulers is such that:
|
To this end, the design of schedulers is such that:
|
||||||
|
|
||||||
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
|
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
|
||||||
- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
|
- Schedulers are currently by default in PyTorch.
|
||||||
|
|
||||||
|
|
||||||
## API
|
## API
|
||||||
|
|||||||
@@ -84,13 +84,10 @@ _deps = [
|
|||||||
"datasets",
|
"datasets",
|
||||||
"filelock",
|
"filelock",
|
||||||
"flake8>=3.8.3",
|
"flake8>=3.8.3",
|
||||||
"flax>=0.4.1",
|
|
||||||
"hf-doc-builder>=0.3.0",
|
"hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub>=0.10.0",
|
"huggingface-hub>=0.10.0",
|
||||||
"importlib_metadata",
|
"importlib_metadata",
|
||||||
"isort>=5.5.4",
|
"isort>=5.5.4",
|
||||||
"jax>=0.2.8,!=0.3.2,<=0.3.6",
|
|
||||||
"jaxlib>=0.1.65,<=0.3.6",
|
|
||||||
"modelcards>=0.1.4",
|
"modelcards>=0.1.4",
|
||||||
"numpy",
|
"numpy",
|
||||||
"onnxruntime",
|
"onnxruntime",
|
||||||
@@ -188,15 +185,9 @@ extras["test"] = deps_list(
|
|||||||
"torchvision",
|
"torchvision",
|
||||||
"transformers"
|
"transformers"
|
||||||
)
|
)
|
||||||
extras["torch"] = deps_list("torch")
|
|
||||||
|
|
||||||
if os.name == "nt": # windows
|
|
||||||
extras["flax"] = [] # jax is not supported on windows
|
|
||||||
else:
|
|
||||||
extras["flax"] = deps_list("jax", "jaxlib", "flax")
|
|
||||||
|
|
||||||
extras["dev"] = (
|
extras["dev"] = (
|
||||||
extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
|
extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
|
||||||
)
|
)
|
||||||
|
|
||||||
install_requires = [
|
install_requires = [
|
||||||
@@ -207,6 +198,7 @@ install_requires = [
|
|||||||
deps["regex"],
|
deps["regex"],
|
||||||
deps["requests"],
|
deps["requests"],
|
||||||
deps["Pillow"],
|
deps["Pillow"],
|
||||||
|
deps["torch"]
|
||||||
]
|
]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
is_flax_available,
|
|
||||||
is_inflect_available,
|
is_inflect_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
is_scipy_available,
|
is_scipy_available,
|
||||||
@@ -61,25 +60,3 @@ if is_torch_available() and is_transformers_available() and is_onnx_available():
|
|||||||
from .pipelines import StableDiffusionOnnxPipeline
|
from .pipelines import StableDiffusionOnnxPipeline
|
||||||
else:
|
else:
|
||||||
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
||||||
|
|
||||||
if is_flax_available():
|
|
||||||
from .modeling_flax_utils import FlaxModelMixin
|
|
||||||
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
|
||||||
from .models.vae_flax import FlaxAutoencoderKL
|
|
||||||
from .pipeline_flax_utils import FlaxDiffusionPipeline
|
|
||||||
from .schedulers import (
|
|
||||||
FlaxDDIMScheduler,
|
|
||||||
FlaxDDPMScheduler,
|
|
||||||
FlaxKarrasVeScheduler,
|
|
||||||
FlaxLMSDiscreteScheduler,
|
|
||||||
FlaxPNDMScheduler,
|
|
||||||
FlaxSchedulerMixin,
|
|
||||||
FlaxScoreSdeVeScheduler,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from .utils.dummy_flax_objects import * # noqa F403
|
|
||||||
|
|
||||||
if is_flax_available() and is_transformers_available():
|
|
||||||
from .pipelines import FlaxStableDiffusionPipeline
|
|
||||||
else:
|
|
||||||
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
|
|
||||||
|
|||||||
@@ -8,13 +8,10 @@ deps = {
|
|||||||
"datasets": "datasets",
|
"datasets": "datasets",
|
||||||
"filelock": "filelock",
|
"filelock": "filelock",
|
||||||
"flake8": "flake8>=3.8.3",
|
"flake8": "flake8>=3.8.3",
|
||||||
"flax": "flax>=0.4.1",
|
|
||||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub": "huggingface-hub>=0.10.0",
|
"huggingface-hub": "huggingface-hub>=0.10.0",
|
||||||
"importlib_metadata": "importlib_metadata",
|
"importlib_metadata": "importlib_metadata",
|
||||||
"isort": "isort>=5.5.4",
|
"isort": "isort>=5.5.4",
|
||||||
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
|
|
||||||
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
|
|
||||||
"modelcards": "modelcards>=0.1.4",
|
"modelcards": "modelcards>=0.1.4",
|
||||||
"numpy": "numpy",
|
"numpy": "numpy",
|
||||||
"onnxruntime": "onnxruntime",
|
"onnxruntime": "onnxruntime",
|
||||||
|
|||||||
@@ -12,14 +12,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ..utils import is_flax_available, is_torch_available
|
from ..utils import is_torch_available
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .unet_2d import UNet2DModel
|
from .unet_2d import UNet2DModel
|
||||||
from .unet_2d_condition import UNet2DConditionModel
|
from .unet_2d_condition import UNet2DConditionModel
|
||||||
from .vae import AutoencoderKL, VQModel
|
from .vae import AutoencoderKL, VQModel
|
||||||
|
|
||||||
if is_flax_available():
|
|
||||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
|
||||||
from .vae_flax import FlaxAutoencoderKL
|
|
||||||
|
|||||||
@@ -21,6 +21,3 @@ if is_torch_available() and is_transformers_available():
|
|||||||
|
|
||||||
if is_transformers_available() and is_onnx_available():
|
if is_transformers_available() and is_onnx_available():
|
||||||
from .stable_diffusion import StableDiffusionOnnxPipeline
|
from .stable_diffusion import StableDiffusionOnnxPipeline
|
||||||
|
|
||||||
if is_transformers_available() and is_flax_available():
|
|
||||||
from .stable_diffusion import FlaxStableDiffusionPipeline
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import PIL
|
import PIL
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
|
from ...utils import BaseOutput, is_onnx_available, is_torch_available, is_transformers_available
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -35,27 +35,3 @@ if is_transformers_available() and is_torch_available():
|
|||||||
|
|
||||||
if is_transformers_available() and is_onnx_available():
|
if is_transformers_available() and is_onnx_available():
|
||||||
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
|
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
|
||||||
|
|
||||||
if is_transformers_available() and is_flax_available():
|
|
||||||
import flax
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
|
||||||
class FlaxStableDiffusionPipelineOutput(BaseOutput):
|
|
||||||
"""
|
|
||||||
Output class for Stable Diffusion pipelines.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
|
||||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
|
||||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
|
||||||
nsfw_content_detected (`List[bool]`)
|
|
||||||
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
|
||||||
(nsfw) content.
|
|
||||||
"""
|
|
||||||
|
|
||||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
|
||||||
nsfw_content_detected: List[bool]
|
|
||||||
|
|
||||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
|
||||||
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
|
|
||||||
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from ..utils import is_flax_available, is_scipy_available, is_torch_available
|
from ..utils import is_scipy_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -27,17 +27,6 @@ if is_torch_available():
|
|||||||
else:
|
else:
|
||||||
from ..utils.dummy_pt_objects import * # noqa F403
|
from ..utils.dummy_pt_objects import * # noqa F403
|
||||||
|
|
||||||
if is_flax_available():
|
|
||||||
from .scheduling_ddim_flax import FlaxDDIMScheduler
|
|
||||||
from .scheduling_ddpm_flax import FlaxDDPMScheduler
|
|
||||||
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
|
|
||||||
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
|
|
||||||
from .scheduling_pndm_flax import FlaxPNDMScheduler
|
|
||||||
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
|
|
||||||
from .scheduling_utils_flax import FlaxSchedulerMixin
|
|
||||||
else:
|
|
||||||
from ..utils.dummy_flax_objects import * # noqa F403
|
|
||||||
|
|
||||||
|
|
||||||
if is_scipy_available() and is_torch_available():
|
if is_scipy_available() and is_torch_available():
|
||||||
from .scheduling_lms_discrete import LMSDiscreteScheduler
|
from .scheduling_lms_discrete import LMSDiscreteScheduler
|
||||||
|
|||||||
Reference in New Issue
Block a user