Replace flake8 with ruff and update black (#2279)
* before running make style * remove left overs from flake8 * finish * make fix-copies * final fix * more fixes
This commit is contained in:
parent
f5ccffecf7
commit
a7ca03aa85
5
.github/workflows/pr_quality.yml
vendored
5
.github/workflows/pr_quality.yml
vendored
@ -27,9 +27,8 @@ jobs:
|
|||||||
pip install .[quality]
|
pip install .[quality]
|
||||||
- name: Check quality
|
- name: Check quality
|
||||||
run: |
|
run: |
|
||||||
black --check --preview examples tests src utils scripts
|
black --check examples tests src utils scripts
|
||||||
isort --check-only examples tests src utils scripts
|
ruff examples tests src utils scripts
|
||||||
flake8 examples tests src utils scripts
|
|
||||||
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
|
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
|
||||||
|
|
||||||
check_repository_consistency:
|
check_repository_consistency:
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -169,3 +169,6 @@ tags
|
|||||||
|
|
||||||
# dependencies
|
# dependencies
|
||||||
/transformers
|
/transformers
|
||||||
|
|
||||||
|
# ruff
|
||||||
|
.ruff_cache
|
||||||
|
|||||||
@ -177,7 +177,7 @@ Follow these steps to start contributing ([supported Python versions](https://gi
|
|||||||
$ make style
|
$ make style
|
||||||
```
|
```
|
||||||
|
|
||||||
🧨 Diffusers also uses `flake8` and a few custom scripts to check for coding mistakes. Quality
|
🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
|
||||||
control runs in CI, however you can also run the same checks with:
|
control runs in CI, however you can also run the same checks with:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
14
Makefile
14
Makefile
@ -9,9 +9,8 @@ modified_only_fixup:
|
|||||||
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
|
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
|
||||||
@if test -n "$(modified_py_files)"; then \
|
@if test -n "$(modified_py_files)"; then \
|
||||||
echo "Checking/fixing $(modified_py_files)"; \
|
echo "Checking/fixing $(modified_py_files)"; \
|
||||||
black --preview $(modified_py_files); \
|
black $(modified_py_files); \
|
||||||
isort $(modified_py_files); \
|
ruff $(modified_py_files); \
|
||||||
flake8 $(modified_py_files); \
|
|
||||||
else \
|
else \
|
||||||
echo "No library .py files were modified"; \
|
echo "No library .py files were modified"; \
|
||||||
fi
|
fi
|
||||||
@ -41,9 +40,8 @@ repo-consistency:
|
|||||||
# this target runs checks on all files
|
# this target runs checks on all files
|
||||||
|
|
||||||
quality:
|
quality:
|
||||||
black --check --preview $(check_dirs)
|
black --check $(check_dirs)
|
||||||
isort --check-only $(check_dirs)
|
ruff $(check_dirs)
|
||||||
flake8 $(check_dirs)
|
|
||||||
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
|
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
|
||||||
python utils/check_doc_toc.py
|
python utils/check_doc_toc.py
|
||||||
|
|
||||||
@ -57,8 +55,8 @@ extra_style_checks:
|
|||||||
# this target runs checks on all files and potentially modifies some of them
|
# this target runs checks on all files and potentially modifies some of them
|
||||||
|
|
||||||
style:
|
style:
|
||||||
black --preview $(check_dirs)
|
black $(check_dirs)
|
||||||
isort $(check_dirs)
|
ruff $(check_dirs) --fix
|
||||||
${MAKE} autogenerate_code
|
${MAKE} autogenerate_code
|
||||||
${MAKE} extra_style_checks
|
${MAKE} extra_style_checks
|
||||||
|
|
||||||
|
|||||||
@ -177,7 +177,7 @@ Follow these steps to start contributing ([supported Python versions](https://gi
|
|||||||
$ make style
|
$ make style
|
||||||
```
|
```
|
||||||
|
|
||||||
🧨 Diffusers also uses `flake8` and a few custom scripts to check for coding mistakes. Quality
|
🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
|
||||||
control runs in CI, however you can also run the same checks with:
|
control runs in CI, however you can also run the same checks with:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@ -210,6 +210,7 @@ torch.set_grad_enabled(False)
|
|||||||
n_experiments = 2
|
n_experiments = 2
|
||||||
unet_runs_per_experiment = 50
|
unet_runs_per_experiment = 50
|
||||||
|
|
||||||
|
|
||||||
# load inputs
|
# load inputs
|
||||||
def generate_inputs():
|
def generate_inputs():
|
||||||
sample = torch.randn(2, 4, 64, 64).half().cuda()
|
sample = torch.randn(2, 4, 64, 64).half().cuda()
|
||||||
@ -288,6 +289,8 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
|||||||
|
|
||||||
# use jitted unet
|
# use jitted unet
|
||||||
unet_traced = torch.jit.load("unet_traced.pt")
|
unet_traced = torch.jit.load("unet_traced.pt")
|
||||||
|
|
||||||
|
|
||||||
# del pipe.unet
|
# del pipe.unet
|
||||||
class TracedUNet(torch.nn.Module):
|
class TracedUNet(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from einops import rearrange, reduce
|
||||||
|
|
||||||
from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
|
from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
|
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
|
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
|
||||||
from einops import rearrange, reduce
|
|
||||||
|
|
||||||
|
|
||||||
BITS = 8
|
BITS = 8
|
||||||
|
|||||||
@ -10,10 +10,11 @@ from diffusers.utils import is_safetensors_available
|
|||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline, __version__
|
from diffusers import DiffusionPipeline, __version__
|
||||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||||
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
|
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
|
|
||||||
class CheckpointMergerPipeline(DiffusionPipeline):
|
class CheckpointMergerPipeline(DiffusionPipeline):
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from typing import List, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
from torchvision import transforms
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
@ -14,8 +16,6 @@ from diffusers import (
|
|||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
)
|
)
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from torchvision import transforms
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class MakeCutouts(nn.Module):
|
class MakeCutouts(nn.Module):
|
||||||
|
|||||||
@ -16,6 +16,8 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
@ -29,8 +31,6 @@ from diffusers.schedulers import (
|
|||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from diffusers.utils import is_accelerate_available
|
from diffusers.utils import is_accelerate_available
|
||||||
from packaging import version
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
from ...utils import deprecate, logging
|
from ...utils import deprecate, logging
|
||||||
from . import StableDiffusionPipelineOutput
|
from . import StableDiffusionPipelineOutput
|
||||||
|
|||||||
@ -7,11 +7,16 @@ import warnings
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
import PIL
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
|
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||||
|
from packaging import version
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
@ -19,11 +24,6 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
|
|||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import deprecate, logging
|
from diffusers.utils import deprecate, logging
|
||||||
|
|
||||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
|
||||||
from packaging import version
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
PIL_INTERPOLATION = {
|
PIL_INTERPOLATION = {
|
||||||
|
|||||||
@ -2,9 +2,10 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
|
import torch
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import deprecate, logging
|
from diffusers.utils import deprecate, logging
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
@ -13,7 +14,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import deprecate, logging
|
from diffusers.utils import deprecate, logging
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@ -3,16 +3,16 @@ import re
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
import diffusers
|
import diffusers
|
||||||
import PIL
|
|
||||||
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||||
from diffusers.utils import deprecate, logging
|
from diffusers.utils import deprecate, logging
|
||||||
from packaging import version
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -3,15 +3,15 @@ import re
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||||
|
|
||||||
import diffusers
|
import diffusers
|
||||||
import PIL
|
|
||||||
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
|
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.utils import deprecate, logging
|
from diffusers.utils import deprecate, logging
|
||||||
from packaging import version
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1,6 +1,10 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms as tfms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
@ -10,10 +14,6 @@ from diffusers import (
|
|||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
)
|
)
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms as tfms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class MagicMixPipeline(DiffusionPipeline):
|
class MagicMixPipeline(DiffusionPipeline):
|
||||||
|
|||||||
@ -2,14 +2,6 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
from diffusers.configuration_utils import FrozenDict
|
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
|
||||||
from diffusers.utils import deprecate, logging
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CLIPFeatureExtractor,
|
CLIPFeatureExtractor,
|
||||||
CLIPTextModel,
|
CLIPTextModel,
|
||||||
@ -19,6 +11,14 @@ from transformers import (
|
|||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from diffusers.configuration_utils import FrozenDict
|
||||||
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
|
from diffusers.utils import deprecate, logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|||||||
@ -17,11 +17,11 @@ import warnings
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
|
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.utils import is_accelerate_available, logging
|
from diffusers.utils import is_accelerate_available, logging
|
||||||
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import logging
|
from diffusers.utils import logging
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@ -2,6 +2,13 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
CLIPFeatureExtractor,
|
||||||
|
CLIPTextModel,
|
||||||
|
CLIPTokenizer,
|
||||||
|
WhisperForConditionalGeneration,
|
||||||
|
WhisperProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
@ -14,13 +21,6 @@ from diffusers import (
|
|||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from diffusers.utils import logging
|
from diffusers.utils import logging
|
||||||
from transformers import (
|
|
||||||
CLIPFeatureExtractor,
|
|
||||||
CLIPTextModel,
|
|
||||||
CLIPTokenizer,
|
|
||||||
WhisperForConditionalGeneration,
|
|
||||||
WhisperProcessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
@ -13,7 +14,6 @@ from diffusers import (
|
|||||||
)
|
)
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
pipe1_model_id = "CompVis/stable-diffusion-v1-1"
|
pipe1_model_id = "CompVis/stable-diffusion-v1-1"
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
import torch
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
@ -17,7 +18,6 @@ from diffusers import (
|
|||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from diffusers.utils import deprecate, logging
|
from diffusers.utils import deprecate, logging
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@ -2,13 +2,13 @@ import types
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
|
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
||||||
|
|
||||||
from diffusers.models import PriorTransformer
|
from diffusers.models import PriorTransformer
|
||||||
from diffusers.pipelines import DiffusionPipeline, StableDiffusionImageVariationPipeline
|
from diffusers.pipelines import DiffusionPipeline, StableDiffusionImageVariationPipeline
|
||||||
from diffusers.schedulers import UnCLIPScheduler
|
from diffusers.schedulers import UnCLIPScheduler
|
||||||
from diffusers.utils import logging, randn_tensor
|
from diffusers.utils import logging, randn_tensor
|
||||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
|
||||||
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@ -1,15 +1,7 @@
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from diffusers import DiffusionPipeline
|
import torch
|
||||||
from diffusers.configuration_utils import FrozenDict
|
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
|
||||||
from diffusers.utils import deprecate, is_accelerate_available, logging
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CLIPFeatureExtractor,
|
CLIPFeatureExtractor,
|
||||||
CLIPSegForImageSegmentation,
|
CLIPSegForImageSegmentation,
|
||||||
@ -18,6 +10,14 @@ from transformers import (
|
|||||||
CLIPTokenizer,
|
CLIPTokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from diffusers.configuration_utils import FrozenDict
|
||||||
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
|
from diffusers.utils import deprecate, is_accelerate_available, logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|||||||
@ -16,14 +16,14 @@ import math
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
|
||||||
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from PIL import Image
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def make_transparency_mask(size, overlap_pixels, remove_borders=[]):
|
def make_transparency_mask(size, overlap_pixels, remove_borders=[]):
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Callable, Dict, List, Optional, Union
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
@ -14,7 +15,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import Stabl
|
|||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import deprecate, logging
|
from diffusers.utils import deprecate, logging
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@ -23,27 +23,27 @@ import warnings
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import accelerate
|
|
||||||
import diffusers
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
from packaging import version
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import AutoTokenizer, PretrainedConfig
|
||||||
|
|
||||||
|
import diffusers
|
||||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
from packaging import version
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import AutoTokenizer, PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
|||||||
@ -6,15 +6,24 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
import optax
|
import optax
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
import transformers
|
import transformers
|
||||||
|
from flax import jax_utils
|
||||||
|
from flax.training import train_state
|
||||||
|
from flax.training.common_utils import shard
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
FlaxAutoencoderKL,
|
FlaxAutoencoderKL,
|
||||||
FlaxDDPMScheduler,
|
FlaxDDPMScheduler,
|
||||||
@ -24,15 +33,6 @@ from diffusers import (
|
|||||||
)
|
)
|
||||||
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
|
||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from flax import jax_utils
|
|
||||||
from flax.training import train_state
|
|
||||||
from flax.training.common_utils import shard
|
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
|||||||
@ -26,13 +26,18 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import diffusers
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import AutoTokenizer, PretrainedConfig
|
||||||
|
|
||||||
|
import diffusers
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
@ -45,11 +50,6 @@ from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
|||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.utils import check_min_version, is_wandb_available
|
from diffusers.utils import check_min_version, is_wandb_available
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import AutoTokenizer, PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
|||||||
@ -5,12 +5,10 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import colossalai
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import colossalai
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||||
@ -18,14 +16,16 @@ from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
|
|||||||
from colossalai.nn.parallel.utils import get_static_torch_model
|
from colossalai.nn.parallel.utils import get_static_torch_model
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
|
||||||
from diffusers.optimization import get_scheduler
|
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import AutoTokenizer, PretrainedConfig
|
from transformers import AutoTokenizer, PretrainedConfig
|
||||||
|
|
||||||
|
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
|
||||||
|
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
|
|||||||
@ -11,11 +11,16 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
@ -25,11 +30,6 @@ from diffusers import (
|
|||||||
)
|
)
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
from PIL import Image, ImageDraw
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
|||||||
@ -10,22 +10,22 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
|
||||||
from diffusers.loaders import AttnProcsLayers
|
from diffusers.loaders import AttnProcsLayers
|
||||||
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
from PIL import Image, ImageDraw
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
from diffusers import StableDiffusionPipeline
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
def image_grid(imgs, rows, cols):
|
def image_grid(imgs, rows, cols):
|
||||||
assert len(imgs) == rows * cols
|
assert len(imgs) == rows * cols
|
||||||
|
|||||||
@ -6,30 +6,30 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
import PIL
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
|
||||||
from diffusers.optimization import get_scheduler
|
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
|
||||||
from diffusers.utils import check_min_version
|
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
|
||||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
|
from diffusers.utils import check_min_version
|
||||||
|
|
||||||
|
|
||||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
PIL_INTERPOLATION = {
|
PIL_INTERPOLATION = {
|
||||||
|
|||||||
@ -8,26 +8,26 @@ import warnings
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import diffusers
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import AutoTokenizer, PretrainedConfig
|
||||||
|
|
||||||
|
import diffusers
|
||||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import AutoTokenizer, PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
|||||||
@ -21,29 +21,29 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
import datasets
|
|
||||||
import diffusers
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
|
||||||
from diffusers.optimization import get_scheduler
|
|
||||||
from diffusers.training_utils import EMAModel
|
|
||||||
from diffusers.utils import check_min_version
|
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
from onnxruntime.training.ortmodule import ORTModule
|
from onnxruntime.training.ortmodule import ORTModule
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.training_utils import EMAModel
|
||||||
|
from diffusers.utils import check_min_version
|
||||||
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.13.0.dev0")
|
check_min_version("0.13.0.dev0")
|
||||||
|
|||||||
@ -21,19 +21,28 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import diffusers
|
|
||||||
import PIL
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
from onnxruntime.training.ortmodule import ORTModule
|
||||||
|
|
||||||
|
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||||
|
from packaging import version
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
import diffusers
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
@ -45,15 +54,6 @@ from diffusers import (
|
|||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.utils import check_min_version, is_wandb_available
|
from diffusers.utils import check_min_version, is_wandb_available
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
from onnxruntime.training.ortmodule import ORTModule
|
|
||||||
|
|
||||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
|
||||||
from packaging import version
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
|
|||||||
@ -6,23 +6,23 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
import datasets
|
|
||||||
import diffusers
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
|
||||||
from diffusers.optimization import get_scheduler
|
|
||||||
from diffusers.training_utils import EMAModel
|
|
||||||
from diffusers.utils import check_min_version
|
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
from onnxruntime.training.ortmodule import ORTModule
|
from onnxruntime.training.ortmodule import ORTModule
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.training_utils import EMAModel
|
||||||
|
from diffusers.utils import check_min_version
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.13.0.dev0")
|
check_min_version("0.13.0.dev0")
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import unittest
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from accelerate.utils import write_basic_config
|
from accelerate.utils import write_basic_config
|
||||||
|
|
||||||
from diffusers.utils import slow
|
from diffusers.utils import slow
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -21,30 +21,30 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import accelerate
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
import accelerate
|
|
||||||
import datasets
|
|
||||||
import diffusers
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
|
||||||
from diffusers.optimization import get_scheduler
|
|
||||||
from diffusers.training_utils import EMAModel
|
|
||||||
from diffusers.utils import check_min_version, deprecate
|
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.training_utils import EMAModel
|
||||||
|
from diffusers.utils import check_min_version, deprecate
|
||||||
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.13.0.dev0")
|
check_min_version("0.13.0.dev0")
|
||||||
|
|||||||
@ -6,15 +6,22 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
import optax
|
import optax
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from flax import jax_utils
|
||||||
|
from flax.training import train_state
|
||||||
|
from flax.training.common_utils import shard
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
FlaxAutoencoderKL,
|
FlaxAutoencoderKL,
|
||||||
FlaxDDPMScheduler,
|
FlaxDDPMScheduler,
|
||||||
@ -24,13 +31,6 @@ from diffusers import (
|
|||||||
)
|
)
|
||||||
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
|
||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from flax import jax_utils
|
|
||||||
from flax.training import train_state
|
|
||||||
from flax.training.common_utils import shard
|
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
|||||||
@ -22,28 +22,28 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
import datasets
|
|
||||||
import diffusers
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
import diffusers
|
||||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||||
from diffusers.loaders import AttnProcsLayers
|
from diffusers.loaders import AttnProcsLayers
|
||||||
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.utils import check_min_version, is_wandb_available
|
from diffusers.utils import check_min_version, is_wandb_available
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
|||||||
@ -22,17 +22,25 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import diffusers
|
|
||||||
import PIL
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
|
||||||
|
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||||
|
from packaging import version
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
import diffusers
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
@ -44,14 +52,6 @@ from diffusers import (
|
|||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.utils import check_min_version, is_wandb_available
|
from diffusers.utils import check_min_version, is_wandb_available
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
|
|
||||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
|
||||||
from packaging import version
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
|
|||||||
@ -6,16 +6,27 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
import optax
|
import optax
|
||||||
import PIL
|
import PIL
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
import transformers
|
import transformers
|
||||||
|
from flax import jax_utils
|
||||||
|
from flax.training import train_state
|
||||||
|
from flax.training.common_utils import shard
|
||||||
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
|
|
||||||
|
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||||
|
from packaging import version
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
FlaxAutoencoderKL,
|
FlaxAutoencoderKL,
|
||||||
FlaxDDPMScheduler,
|
FlaxDDPMScheduler,
|
||||||
@ -25,17 +36,6 @@ from diffusers import (
|
|||||||
)
|
)
|
||||||
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
|
||||||
from diffusers.utils import check_min_version
|
from diffusers.utils import check_min_version
|
||||||
from flax import jax_utils
|
|
||||||
from flax.training import train_state
|
|
||||||
from flax.training.common_utils import shard
|
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
|
||||||
|
|
||||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
|
||||||
from packaging import version
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
|
||||||
|
|
||||||
|
|
||||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
|
|||||||
@ -6,24 +6,24 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
import datasets
|
import datasets
|
||||||
import diffusers
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
|
||||||
from diffusers.optimization import get_scheduler
|
|
||||||
from diffusers.training_utils import EMAModel
|
|
||||||
from diffusers.utils import check_min_version
|
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.training_utils import EMAModel
|
||||||
|
from diffusers.utils import check_min_version
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.13.0.dev0")
|
check_min_version("0.13.0.dev0")
|
||||||
|
|||||||
@ -1,3 +1,18 @@
|
|||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 119
|
line-length = 119
|
||||||
target-version = ['py36']
|
target-version = ['py37']
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
# Never enforce `E501` (line length violations).
|
||||||
|
ignore = ["E501", "E741", "W605"]
|
||||||
|
select = ["E", "F", "I", "W"]
|
||||||
|
line-length = 119
|
||||||
|
|
||||||
|
# Ignore import violations in all `__init__.py` files.
|
||||||
|
[tool.ruff.per-file-ignores]
|
||||||
|
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
||||||
|
"src/diffusers/utils/dummy_*.py" = ["F401"]
|
||||||
|
|
||||||
|
[tool.ruff.isort]
|
||||||
|
lines-after-imports = 2
|
||||||
|
known-first-party = ["diffusers"]
|
||||||
|
|||||||
@ -19,9 +19,9 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers.file_utils import has_file
|
||||||
|
|
||||||
from diffusers import UNet2DConditionModel, UNet2DModel
|
from diffusers import UNet2DConditionModel, UNet2DModel
|
||||||
from transformers.file_utils import has_file
|
|
||||||
|
|
||||||
|
|
||||||
do_only_config = False
|
do_only_config = False
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import OmegaConf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import OmegaConf
|
|
||||||
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
|
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -5,11 +5,11 @@ import os
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from audio_diffusion.models import DiffusionAttnUnet1D
|
||||||
|
from diffusion import sampling
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from audio_diffusion.models import DiffusionAttnUnet1D
|
|
||||||
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
|
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
|
||||||
from diffusion import sampling
|
|
||||||
|
|
||||||
|
|
||||||
MODELS_MAP = {
|
MODELS_MAP = {
|
||||||
|
|||||||
@ -7,7 +7,6 @@ import os.path as osp
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,9 +2,9 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torchvision.datasets.utils import download_url
|
||||||
|
|
||||||
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
|
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
|
||||||
from torchvision.datasets.utils import download_url
|
|
||||||
|
|
||||||
|
|
||||||
pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
|
pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
import k_diffusion as K
|
import k_diffusion as K
|
||||||
|
import torch
|
||||||
|
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,13 +2,13 @@ import argparse
|
|||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from accelerate import load_checkpoint_and_dispatch
|
from accelerate import load_checkpoint_and_dispatch
|
||||||
|
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
|
from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
|
||||||
from diffusers.models.prior_transformer import PriorTransformer
|
from diffusers.models.prior_transformer import PriorTransformer
|
||||||
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
||||||
from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
|
from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
|
||||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -249,7 +249,6 @@ DECODER_CONFIG = {
|
|||||||
"class_embed_type": "identity",
|
"class_embed_type": "identity",
|
||||||
"attention_head_dim": 64,
|
"attention_head_dim": 64,
|
||||||
"resnet_time_scale_shift": "scale_shift",
|
"resnet_time_scale_shift": "scale_shift",
|
||||||
"class_embed_type": "identity",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -355,5 +355,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
|
pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
|
||||||
pipe.save_pretrained(args.dump_path)
|
pipe.save_pretrained(args.dump_path)
|
||||||
except:
|
except: # noqa: E722
|
||||||
model.save_pretrained(args.dump_path)
|
model.save_pretrained(args.dump_path)
|
||||||
|
|||||||
@ -181,5 +181,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
|
pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
|
||||||
pipe.save_pretrained(args.dump_path)
|
pipe.save_pretrained(args.dump_path)
|
||||||
except:
|
except: # noqa: E722
|
||||||
model.save_pretrained(args.dump_path)
|
model.save_pretrained(args.dump_path)
|
||||||
|
|||||||
@ -17,12 +17,12 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
import onnx
|
|
||||||
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline
|
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
|
|
||||||
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
|
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from diffusers import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
|
||||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
|
from diffusers import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import io
|
import io
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||||
assign_to_checkpoint,
|
assign_to_checkpoint,
|
||||||
@ -12,7 +13,6 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
|||||||
renew_vae_attention_paths,
|
renew_vae_attention_paths,
|
||||||
renew_vae_resnet_paths,
|
renew_vae_resnet_paths,
|
||||||
)
|
)
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
|
|
||||||
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||||
|
|||||||
@ -18,6 +18,12 @@ import argparse
|
|||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
CLIPFeatureExtractor,
|
||||||
|
CLIPTextModelWithProjection,
|
||||||
|
CLIPTokenizer,
|
||||||
|
CLIPVisionModelWithProjection,
|
||||||
|
)
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
@ -31,12 +37,6 @@ from diffusers import (
|
|||||||
VersatileDiffusionPipeline,
|
VersatileDiffusionPipeline,
|
||||||
)
|
)
|
||||||
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel
|
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel
|
||||||
from transformers import (
|
|
||||||
CLIPFeatureExtractor,
|
|
||||||
CLIPTextModelWithProjection,
|
|
||||||
CLIPTokenizer,
|
|
||||||
CLIPVisionModelWithProjection,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
SCHEDULER_CONFIG = Namespace(
|
SCHEDULER_CONFIG = Namespace(
|
||||||
|
|||||||
@ -36,14 +36,14 @@ import argparse
|
|||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
||||||
from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
|
|
||||||
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
from yaml.loader import FullLoader
|
from yaml.loader import FullLoader
|
||||||
|
|
||||||
|
from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
|
||||||
|
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
from diffusers import UNet2DModel
|
from diffusers import UNet2DModel
|
||||||
from huggingface_hub import HfApi
|
|
||||||
|
|
||||||
|
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
|
|||||||
6
setup.py
6
setup.py
@ -80,10 +80,9 @@ from setuptools import find_packages, setup
|
|||||||
_deps = [
|
_deps = [
|
||||||
"Pillow", # keep the PIL.Image.Resampling deprecation away
|
"Pillow", # keep the PIL.Image.Resampling deprecation away
|
||||||
"accelerate>=0.11.0",
|
"accelerate>=0.11.0",
|
||||||
"black==22.12",
|
"black~=23.1",
|
||||||
"datasets",
|
"datasets",
|
||||||
"filelock",
|
"filelock",
|
||||||
"flake8>=3.8.3",
|
|
||||||
"flax>=0.4.1",
|
"flax>=0.4.1",
|
||||||
"hf-doc-builder>=0.3.0",
|
"hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub>=0.10.0",
|
"huggingface-hub>=0.10.0",
|
||||||
@ -99,6 +98,7 @@ _deps = [
|
|||||||
"pytest",
|
"pytest",
|
||||||
"pytest-timeout",
|
"pytest-timeout",
|
||||||
"pytest-xdist",
|
"pytest-xdist",
|
||||||
|
"ruff>=0.0.241",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
"sentencepiece>=0.1.91,!=0.1.92",
|
"sentencepiece>=0.1.91,!=0.1.92",
|
||||||
"scipy",
|
"scipy",
|
||||||
@ -178,7 +178,7 @@ extras = {}
|
|||||||
|
|
||||||
|
|
||||||
extras = {}
|
extras = {}
|
||||||
extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
|
extras["quality"] = deps_list("black", "isort", "ruff", "hf-doc-builder")
|
||||||
extras["docs"] = deps_list("hf-doc-builder")
|
extras["docs"] = deps_list("hf-doc-builder")
|
||||||
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2")
|
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2")
|
||||||
extras["test"] = deps_list(
|
extras["test"] = deps_list(
|
||||||
|
|||||||
@ -26,7 +26,6 @@ from pathlib import PosixPath
|
|||||||
from typing import Any, Dict, Tuple, Union
|
from typing import Any, Dict, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
|
|||||||
@ -4,10 +4,9 @@
|
|||||||
deps = {
|
deps = {
|
||||||
"Pillow": "Pillow",
|
"Pillow": "Pillow",
|
||||||
"accelerate": "accelerate>=0.11.0",
|
"accelerate": "accelerate>=0.11.0",
|
||||||
"black": "black==22.12",
|
"black": "black~=23.1",
|
||||||
"datasets": "datasets",
|
"datasets": "datasets",
|
||||||
"filelock": "filelock",
|
"filelock": "filelock",
|
||||||
"flake8": "flake8>=3.8.3",
|
|
||||||
"flax": "flax>=0.4.1",
|
"flax": "flax>=0.4.1",
|
||||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub": "huggingface-hub>=0.10.0",
|
"huggingface-hub": "huggingface-hub>=0.10.0",
|
||||||
@ -23,6 +22,7 @@ deps = {
|
|||||||
"pytest": "pytest",
|
"pytest": "pytest",
|
||||||
"pytest-timeout": "pytest-timeout",
|
"pytest-timeout": "pytest-timeout",
|
||||||
"pytest-xdist": "pytest-xdist",
|
"pytest-xdist": "pytest-xdist",
|
||||||
|
"ruff": "ruff>=0.0.241",
|
||||||
"safetensors": "safetensors",
|
"safetensors": "safetensors",
|
||||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||||
"scipy": "scipy",
|
"scipy": "scipy",
|
||||||
|
|||||||
@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from ...models.unet_1d import UNet1DModel
|
from ...models.unet_1d import UNet1DModel
|
||||||
@ -57,13 +56,13 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
|
|||||||
for key in self.data.keys():
|
for key in self.data.keys():
|
||||||
try:
|
try:
|
||||||
self.means[key] = self.data[key].mean()
|
self.means[key] = self.data[key].mean()
|
||||||
except:
|
except: # noqa: E722
|
||||||
pass
|
pass
|
||||||
self.stds = dict()
|
self.stds = dict()
|
||||||
for key in self.data.keys():
|
for key in self.data.keys():
|
||||||
try:
|
try:
|
||||||
self.stds[key] = self.data[key].std()
|
self.stds[key] = self.data[key].std()
|
||||||
except:
|
except: # noqa: E722
|
||||||
pass
|
pass
|
||||||
self.state_dim = env.observation_space.shape[0]
|
self.state_dim = env.observation_space.shape[0]
|
||||||
self.action_dim = env.action_space.shape[0]
|
self.action_dim = env.action_space.shape[0]
|
||||||
|
|||||||
@ -16,10 +16,9 @@
|
|||||||
|
|
||||||
from pickle import UnpicklingError
|
from pickle import UnpicklingError
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
from flax.serialization import from_bytes
|
from flax.serialization import from_bytes
|
||||||
from flax.traverse_util import flatten_dict
|
from flax.traverse_util import flatten_dict
|
||||||
|
|
||||||
|
|||||||
@ -20,11 +20,10 @@ from functools import partial
|
|||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, device
|
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
|
from torch import Tensor, device
|
||||||
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
@ -500,7 +499,7 @@ class ModelMixin(torch.nn.Module):
|
|||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
except:
|
except: # noqa: E722
|
||||||
pass
|
pass
|
||||||
if model_file is None:
|
if model_file is None:
|
||||||
model_file = _get_model_file(
|
model_file = _get_model_file(
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from dataclasses import dataclass
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
|
from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
|
||||||
from transformers.utils import ModelOutput
|
from transformers.utils import ModelOutput
|
||||||
|
|
||||||
|
|||||||
@ -16,11 +16,11 @@ import inspect
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers.utils import is_accelerate_available
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
|
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
|
||||||
|
|
||||||
|
from diffusers.utils import is_accelerate_available
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import KarrasDiffusionSchedulers
|
from ...schedulers import KarrasDiffusionSchedulers
|
||||||
|
|||||||
@ -16,13 +16,13 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from diffusers.utils import is_accelerate_available
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
|
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
|
||||||
|
|
||||||
|
from diffusers.utils import is_accelerate_available
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import KarrasDiffusionSchedulers
|
from ...schedulers import KarrasDiffusionSchedulers
|
||||||
|
|||||||
@ -1,3 +1,2 @@
|
|||||||
# flake8: noqa
|
|
||||||
from .mel import Mel
|
from .mel import Mel
|
||||||
from .pipeline_audio_diffusion import AudioDiffusionPipeline
|
from .pipeline_audio_diffusion import AudioDiffusionPipeline
|
||||||
|
|||||||
@ -18,7 +18,6 @@ from typing import List, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
|
|||||||
@ -1,2 +1 @@
|
|||||||
# flake8: noqa
|
|
||||||
from .pipeline_dance_diffusion import DanceDiffusionPipeline
|
from .pipeline_dance_diffusion import DanceDiffusionPipeline
|
||||||
|
|||||||
@ -1,2 +1 @@
|
|||||||
# flake8: noqa
|
|
||||||
from .pipeline_ddim import DDIMPipeline
|
from .pipeline_ddim import DDIMPipeline
|
||||||
|
|||||||
@ -1,2 +1 @@
|
|||||||
# flake8: noqa
|
|
||||||
from .pipeline_ddpm import DDPMPipeline
|
from .pipeline_ddpm import DDPMPipeline
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
# flake8: noqa
|
|
||||||
from ...utils import is_transformers_available
|
from ...utils import is_transformers_available
|
||||||
from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
|
from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,6 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.modeling_outputs import BaseModelOutput
|
from transformers.modeling_outputs import BaseModelOutput
|
||||||
|
|||||||
@ -2,11 +2,10 @@ import inspect
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
import PIL
|
|
||||||
|
|
||||||
from ...models import UNet2DModel, VQModel
|
from ...models import UNet2DModel, VQModel
|
||||||
from ...schedulers import (
|
from ...schedulers import (
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
|
|||||||
@ -1,2 +1 @@
|
|||||||
# flake8: noqa
|
|
||||||
from .pipeline_latent_diffusion_uncond import LDMPipeline
|
from .pipeline_latent_diffusion_uncond import LDMPipeline
|
||||||
|
|||||||
@ -21,7 +21,6 @@ from pathlib import Path
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
|
from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from dataclasses import dataclass
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|||||||
@ -13,7 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import CLIPPreTrainedModel, CLIPVisionModel
|
from transformers import CLIPPreTrainedModel, CLIPVisionModel
|
||||||
|
|
||||||
from ...models.attention import BasicTransformerBlock
|
from ...models.attention import BasicTransformerBlock
|
||||||
|
|||||||
@ -16,12 +16,12 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from diffusers.utils import is_accelerate_available
|
import torch
|
||||||
from transformers import CLIPFeatureExtractor
|
from transformers import CLIPFeatureExtractor
|
||||||
|
|
||||||
|
from diffusers.utils import is_accelerate_available
|
||||||
|
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from ...utils import logging, randn_tensor
|
from ...utils import logging, randn_tensor
|
||||||
|
|||||||
@ -19,9 +19,8 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import flax
|
import flax
|
||||||
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|||||||
@ -22,15 +22,15 @@ from pathlib import Path
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import diffusers
|
|
||||||
import PIL
|
import PIL
|
||||||
|
import torch
|
||||||
from huggingface_hub import model_info, snapshot_download
|
from huggingface_hub import model_info, snapshot_download
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin
|
from ..configuration_utils import ConfigMixin
|
||||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
||||||
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||||
|
|||||||
@ -1,2 +1 @@
|
|||||||
# flake8: noqa
|
|
||||||
from .pipeline_pndm import PNDMPipeline
|
from .pipeline_pndm import PNDMPipeline
|
||||||
|
|||||||
@ -16,9 +16,8 @@
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
|
import torch
|
||||||
|
|
||||||
from ...models import UNet2DModel
|
from ...models import UNet2DModel
|
||||||
from ...schedulers import RePaintScheduler
|
from ...schedulers import RePaintScheduler
|
||||||
|
|||||||
@ -1,2 +1 @@
|
|||||||
# flake8: noqa
|
|
||||||
from .pipeline_score_sde_ve import ScoreSdeVePipeline
|
from .pipeline_score_sde_ve import ScoreSdeVePipeline
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from dataclasses import dataclass
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|||||||
@ -18,9 +18,10 @@ import os
|
|||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
|
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
@ -37,7 +38,6 @@ from diffusers import (
|
|||||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
|
|
||||||
|
|
||||||
from ...utils import is_omegaconf_available, is_safetensors_available
|
from ...utils import is_omegaconf_available, is_safetensors_available
|
||||||
from ...utils.import_utils import BACKENDS_MAPPING
|
from ...utils.import_utils import BACKENDS_MAPPING
|
||||||
|
|||||||
@ -16,13 +16,13 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from diffusers.utils import is_accelerate_available
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
from diffusers.utils import is_accelerate_available
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler
|
from ...schedulers import DDIMScheduler
|
||||||
|
|||||||
@ -16,10 +16,9 @@ import warnings
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from flax.jax_utils import unreplicate
|
from flax.jax_utils import unreplicate
|
||||||
from flax.training.common_utils import shard
|
from flax.training.common_utils import shard
|
||||||
|
|||||||
@ -16,10 +16,9 @@ import warnings
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from flax.jax_utils import unreplicate
|
from flax.jax_utils import unreplicate
|
||||||
from flax.training.common_utils import shard
|
from flax.training.common_utils import shard
|
||||||
|
|||||||
@ -16,10 +16,9 @@ import warnings
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from flax.jax_utils import unreplicate
|
from flax.jax_utils import unreplicate
|
||||||
from flax.training.common_utils import shard
|
from flax.training.common_utils import shard
|
||||||
|
|||||||
@ -17,7 +17,6 @@ from typing import Callable, List, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
|
|||||||
@ -16,9 +16,8 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
|
import torch
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
|
|||||||
@ -16,9 +16,8 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
|
import torch
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
|
|||||||
@ -2,9 +2,8 @@ import inspect
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
|
import torch
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
|
|||||||
@ -16,7 +16,6 @@ import inspect
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user