Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6aaa4aac24 | |||
| a6375d4101 | |||
| 8e1b7a084a | |||
| 6946facf69 | |||
| 130dd936bb | |||
| a899e42fc7 | |||
| f96e4a16ad | |||
| 9c6e9684a2 | |||
| 2e4841ef1e | |||
| 8bea943714 | |||
| 614d0c64e9 | |||
| b1a2c0d577 | |||
| 06ee907b73 | |||
| 896fb6d8d7 | |||
| 7f51f286a5 | |||
| 829f6defa4 | |||
| 24bdf4b215 | |||
| 95e0c3757d | |||
| 6cf0be5d3d |
@@ -54,7 +54,7 @@ jobs:
|
||||
else
|
||||
# e.g. refs/tags/v0.28.1 -> v0.28.1
|
||||
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
|
||||
echo "PATH_IN_REPO=${${{ github.ref }}#refs/tags/}" >> $GITHUB_ENV
|
||||
echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Print env vars
|
||||
run: |
|
||||
|
||||
@@ -42,7 +42,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
|
||||
@@ -40,7 +40,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
@@ -41,8 +41,8 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
librosa \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
@@ -40,7 +40,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
@@ -40,7 +40,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
@@ -40,7 +40,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers matplotlib
|
||||
|
||||
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
|
||||
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
|
||||
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
|
||||
- [`StableDiffusionXLInstructPix2PixPipeline`]
|
||||
- [`StableDiffusionXLControlNetPipeline`]
|
||||
- [`StableDiffusionXLKDiffusionPipeline`]
|
||||
- [`StableDiffusion3Pipeline`]
|
||||
- [`LatentConsistencyModelPipeline`]
|
||||
- [`LatentConsistencyModelImg2ImgPipeline`]
|
||||
- [`StableDiffusionControlNetXSPipeline`]
|
||||
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
|
||||
- [`StableCascadeUNet`]
|
||||
- [`AutoencoderKL`]
|
||||
- [`ControlNetModel`]
|
||||
- [`SD3Transformer2DModel`]
|
||||
|
||||
## FromSingleFileMixin
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ The abstract from the paper is:
|
||||
|
||||
## Usage Example
|
||||
|
||||
_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
|
||||
_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
|
||||
|
||||
Use the command below to log in:
|
||||
Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
@@ -186,7 +186,7 @@ pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgrap
|
||||
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
|
||||
|
||||
# Warm Up
|
||||
prompt = "a photo of a cat holding a sign that says hello world",
|
||||
prompt = "a photo of a cat holding a sign that says hello world"
|
||||
for _ in range(3):
|
||||
_ = pipe(prompt=prompt, generator=torch.manual_seed(1))
|
||||
|
||||
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability
|
||||
|
||||
## Loading the single checkpoint for the `StableDiffusion3Pipeline`
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
from transformers import T5EncoderModel
|
||||
### Loading the single file checkpoint without T5
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
pipe = StableDiffusion3Pipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
|
||||
torch_dtype=torch.float16,
|
||||
text_encoder_3=None
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
|
||||
image.save('sd3-single-file.png')
|
||||
```
|
||||
|
||||
<Tip>
|
||||
`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
|
||||
</Tip>
|
||||
### Loading the single file checkpoint without T5
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
pipe = StableDiffusion3Pipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
|
||||
image.save('sd3-single-file-t5-fp8.png')
|
||||
```
|
||||
|
||||
## StableDiffusion3Pipeline
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -4,4 +4,5 @@ transformers>=4.41.2
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft== 0.11.1
|
||||
peft==0.11.1
|
||||
sentencepiece
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.28.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1316,6 +1316,9 @@ def main(args):
|
||||
# Clear the memory here
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
del tokenizers, text_encoders
|
||||
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||
del tokenizer_one, tokenizer_two, tokenizer_three
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
@@ -1462,7 +1465,18 @@ def main(args):
|
||||
bsz = model_input.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
if args.weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
|
||||
u = torch.nn.functional.sigmoid(u)
|
||||
elif args.weighting_scheme == "mode":
|
||||
u = torch.rand(size=(bsz,), device="cpu")
|
||||
u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
else:
|
||||
u = torch.rand(size=(bsz,), device="cpu")
|
||||
|
||||
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
||||
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
@@ -1483,16 +1497,15 @@ def main(args):
|
||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
|
||||
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
|
||||
# these weighting schemes use a uniform timestep sampling
|
||||
# and instead post-weight the loss
|
||||
if args.weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
elif args.weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
|
||||
weighting = torch.nn.functional.sigmoid(u)
|
||||
elif args.weighting_scheme == "mode":
|
||||
# See sec 3.1 in the SD3 paper (20).
|
||||
u = torch.rand(size=(bsz,), device=accelerator.device)
|
||||
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
elif args.weighting_scheme == "cosmap":
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
weighting = 2 / (math.pi * bot)
|
||||
else:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
|
||||
# simplified flow matching aka 0-rectified flow matching loss
|
||||
# target = model_input - noise
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.28.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1347,6 +1347,9 @@ def main(args):
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
del tokenizers, text_encoders
|
||||
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||
del tokenizer_one, tokenizer_two, tokenizer_three
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
@@ -1526,7 +1529,18 @@ def main(args):
|
||||
bsz = model_input.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
if args.weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
|
||||
u = torch.nn.functional.sigmoid(u)
|
||||
elif args.weighting_scheme == "mode":
|
||||
u = torch.rand(size=(bsz,), device="cpu")
|
||||
u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
else:
|
||||
u = torch.rand(size=(bsz,), device="cpu")
|
||||
|
||||
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
||||
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
@@ -1560,18 +1574,15 @@ def main(args):
|
||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||
# Preconditioning of the model outputs.
|
||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
|
||||
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
|
||||
# these weighting schemes use a uniform timestep sampling
|
||||
# and instead post-weight the loss
|
||||
if args.weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
elif args.weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
|
||||
weighting = torch.nn.functional.sigmoid(u)
|
||||
elif args.weighting_scheme == "mode":
|
||||
# See sec 3.1 in the SD3 paper (20).
|
||||
u = torch.rand(size=(bsz,), device=accelerator.device)
|
||||
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
elif args.weighting_scheme == "cosmap":
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
weighting = 2 / (math.pi * bot)
|
||||
else:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
|
||||
# simplified flow matching aka 0-rectified flow matching loss
|
||||
# target = model_input - noise
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.29.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ from setuptools import Command, find_packages, setup
|
||||
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
|
||||
_deps = [
|
||||
"Pillow", # keep the PIL.Image.Resampling deprecation away
|
||||
"accelerate>=0.29.3",
|
||||
"accelerate>=0.31.0",
|
||||
"compel==0.1.8",
|
||||
"datasets",
|
||||
"filelock",
|
||||
@@ -132,7 +132,7 @@ _deps = [
|
||||
"tensorboard",
|
||||
"torch>=1.4",
|
||||
"torchvision",
|
||||
"transformers>=4.25.1",
|
||||
"transformers>=4.41.2",
|
||||
"urllib3<=2.0.0",
|
||||
"black",
|
||||
]
|
||||
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.29.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.30.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.29.0.dev0"
|
||||
__version__ = "0.30.0.dev0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# 2. run `make deps_table_update`
|
||||
deps = {
|
||||
"Pillow": "Pillow",
|
||||
"accelerate": "accelerate>=0.29.3",
|
||||
"accelerate": "accelerate>=0.31.0",
|
||||
"compel": "compel==0.1.8",
|
||||
"datasets": "datasets",
|
||||
"filelock": "filelock",
|
||||
@@ -40,7 +40,7 @@ deps = {
|
||||
"tensorboard": "tensorboard",
|
||||
"torch": "torch>=1.4",
|
||||
"torchvision": "torchvision",
|
||||
"transformers": "transformers>=4.25.1",
|
||||
"transformers": "transformers>=4.41.2",
|
||||
"urllib3": "urllib3<=2.0.0",
|
||||
"black": "black",
|
||||
}
|
||||
|
||||
@@ -569,7 +569,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
channel = image.shape[1]
|
||||
# don't need any preprocess if the image is latents
|
||||
if channel == 4:
|
||||
if channel == self.vae_latent_channels:
|
||||
return image
|
||||
|
||||
height, width = self.get_default_height_width(image, height, width)
|
||||
@@ -585,7 +585,6 @@ class VaeImageProcessor(ConfigMixin):
|
||||
FutureWarning,
|
||||
)
|
||||
do_normalize = False
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image)
|
||||
|
||||
|
||||
@@ -28,9 +28,11 @@ from .single_file_utils import (
|
||||
_legacy_load_safety_checker,
|
||||
_legacy_load_scheduler,
|
||||
create_diffusers_clip_model_from_ldm,
|
||||
create_diffusers_t5_model_from_checkpoint,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
is_clip_model_in_single_file,
|
||||
is_t5_in_single_file,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_t5_in_single_file(checkpoint):
|
||||
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
elif is_tokenizer and is_legacy_loading:
|
||||
loaded_sub_model = _legacy_load_clip_tokenizer(
|
||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||
|
||||
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = [
|
||||
"cond_stage_model.transformer.",
|
||||
"conditioner.embedders.0.transformer.",
|
||||
"text_encoders.clip_l.transformer.",
|
||||
]
|
||||
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
|
||||
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
||||
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
|
||||
|
||||
|
||||
def is_open_clip_sd3_model(checkpoint):
|
||||
is_open_clip_sdxl_refiner_model(checkpoint)
|
||||
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_open_clip_sdxl_refiner_model(checkpoint):
|
||||
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint):
|
||||
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
|
||||
keys = list(checkpoint.keys())
|
||||
text_model_dict = {}
|
||||
|
||||
remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
|
||||
remove_prefixes = []
|
||||
remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
|
||||
if remove_prefix:
|
||||
remove_prefixes.append(remove_prefix)
|
||||
|
||||
for key in keys:
|
||||
for prefix in remove_prefixes:
|
||||
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
|
||||
):
|
||||
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
|
||||
|
||||
elif (
|
||||
is_clip_sd3_model(checkpoint)
|
||||
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
|
||||
):
|
||||
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
|
||||
diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
|
||||
|
||||
elif is_open_clip_model(checkpoint):
|
||||
prefix = "cond_stage_model.model."
|
||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
|
||||
prefix = "conditioner.embedders.0.model."
|
||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||
|
||||
elif is_open_clip_sd3_model(checkpoint):
|
||||
prefix = "text_encoders.clip_g.transformer."
|
||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||
elif (
|
||||
is_open_clip_sd3_model(checkpoint)
|
||||
and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
|
||||
):
|
||||
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
|
||||
|
||||
else:
|
||||
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
|
||||
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
text_model_dict = {}
|
||||
|
||||
remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
|
||||
remove_prefixes = ["text_encoders.t5xxl.transformer."]
|
||||
|
||||
for key in keys:
|
||||
for prefix in remove_prefixes:
|
||||
|
||||
@@ -1132,9 +1132,7 @@ class JointAttnProcessor2_0:
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
@@ -1406,7 +1404,6 @@ class XFormersAttnProcessor:
|
||||
|
||||
|
||||
class AttnProcessorNPU:
|
||||
|
||||
r"""
|
||||
Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
|
||||
fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
|
||||
|
||||
@@ -282,9 +282,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from .controlnet import (
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
KandinskyImg2ImgCombinedPipeline,
|
||||
@@ -52,6 +53,10 @@ from .stable_diffusion import (
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from .stable_diffusion_3 import (
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
@@ -64,7 +69,9 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion", StableDiffusionPipeline),
|
||||
("stable-diffusion-xl", StableDiffusionXLPipeline),
|
||||
("stable-diffusion-3", StableDiffusion3Pipeline),
|
||||
("if", IFPipeline),
|
||||
("hunyuan", HunyuanDiTPipeline),
|
||||
("kandinsky", KandinskyCombinedPipeline),
|
||||
("kandinsky22", KandinskyV22CombinedPipeline),
|
||||
("kandinsky3", Kandinsky3Pipeline),
|
||||
@@ -82,6 +89,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion", StableDiffusionImg2ImgPipeline),
|
||||
("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
|
||||
("stable-diffusion-3", StableDiffusion3Img2ImgPipeline),
|
||||
("if", IFImg2ImgPipeline),
|
||||
("kandinsky", KandinskyImg2ImgCombinedPipeline),
|
||||
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
|
||||
|
||||
@@ -118,129 +118,6 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
|
||||
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
|
||||
"""
|
||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||
``image`` and ``1`` for the ``mask``.
|
||||
|
||||
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
||||
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
||||
|
||||
Args:
|
||||
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
||||
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
||||
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
||||
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
||||
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
||||
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
||||
(ot the other way around).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
|
||||
deprecate(
|
||||
"prepare_mask_and_masked_image",
|
||||
"0.30.0",
|
||||
deprecation_message,
|
||||
)
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
if mask is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined.")
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
|
||||
|
||||
# Batch single image
|
||||
if image.ndim == 3:
|
||||
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
# Batch and add channel dim for single mask
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Batch single mask or add channel dim
|
||||
if mask.ndim == 3:
|
||||
# Single batched mask, no channel dim or single mask not batched but channel dim
|
||||
if mask.shape[0] == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
# Batched masks no channel dim
|
||||
else:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
|
||||
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
|
||||
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
|
||||
|
||||
# Check image is in [-1, 1]
|
||||
if image.min() < -1 or image.max() > 1:
|
||||
raise ValueError("Image should be in [-1, 1] range")
|
||||
|
||||
# Check mask is in [0, 1]
|
||||
if mask.min() < 0 or mask.max() > 1:
|
||||
raise ValueError("Mask should be in [0, 1] range")
|
||||
|
||||
# Binarize mask
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
# Image as float32
|
||||
image = image.to(dtype=torch.float32)
|
||||
elif isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
|
||||
else:
|
||||
# preprocess image
|
||||
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
||||
# resize all images w.r.t passed height an width
|
||||
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
|
||||
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
|
||||
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
# n.b. ensure backwards compatibility as old function does not return image
|
||||
if return_image:
|
||||
return mask, masked_image, image
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class StableDiffusionControlNetInpaintPipeline(
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -38,128 +37,6 @@ from .safety_checker import StableDiffusionSafetyChecker
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
|
||||
"""
|
||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||
``image`` and ``1`` for the ``mask``.
|
||||
|
||||
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
||||
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
||||
|
||||
Args:
|
||||
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
||||
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
||||
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
||||
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
||||
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
||||
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
||||
(ot the other way around).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
|
||||
deprecate(
|
||||
"prepare_mask_and_masked_image",
|
||||
"0.30.0",
|
||||
deprecation_message,
|
||||
)
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
if mask is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined.")
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
|
||||
|
||||
# Batch single image
|
||||
if image.ndim == 3:
|
||||
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
# Batch and add channel dim for single mask
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Batch single mask or add channel dim
|
||||
if mask.ndim == 3:
|
||||
# Single batched mask, no channel dim or single mask not batched but channel dim
|
||||
if mask.shape[0] == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
# Batched masks no channel dim
|
||||
else:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
|
||||
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
|
||||
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
|
||||
|
||||
# Check image is in [-1, 1]
|
||||
if image.min() < -1 or image.max() > 1:
|
||||
raise ValueError("Image should be in [-1, 1] range")
|
||||
|
||||
# Check mask is in [0, 1]
|
||||
if mask.min() < 0 or mask.max() > 1:
|
||||
raise ValueError("Mask should be in [0, 1] range")
|
||||
|
||||
# Binarize mask
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
# Image as float32
|
||||
image = image.to(dtype=torch.float32)
|
||||
elif isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
|
||||
else:
|
||||
# preprocess image
|
||||
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
||||
# resize all images w.r.t passed height an width
|
||||
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
|
||||
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
|
||||
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
# n.b. ensure backwards compatibility as old function does not return image
|
||||
if return_image:
|
||||
return mask, masked_image, image
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
|
||||
@@ -591,8 +591,6 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] == self.vae.config.latent_channels:
|
||||
init_latents = image
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if image.shape[1] == self.vae.config.latent_channels:
|
||||
|
||||
@@ -132,124 +132,6 @@ def mask_pil_to_torch(mask, height, width):
|
||||
return mask
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
|
||||
"""
|
||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||
``image`` and ``1`` for the ``mask``.
|
||||
|
||||
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
||||
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
||||
|
||||
Args:
|
||||
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
||||
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
||||
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
||||
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
||||
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
||||
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
||||
(ot the other way around).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
|
||||
# checkpoint. TOD(Yiyi) - need to clean this up later
|
||||
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
|
||||
deprecate(
|
||||
"prepare_mask_and_masked_image",
|
||||
"0.30.0",
|
||||
deprecation_message,
|
||||
)
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
if mask is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined.")
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
mask = mask_pil_to_torch(mask, height, width)
|
||||
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
# Batch and add channel dim for single mask
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Batch single mask or add channel dim
|
||||
if mask.ndim == 3:
|
||||
# Single batched mask, no channel dim or single mask not batched but channel dim
|
||||
if mask.shape[0] == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
# Batched masks no channel dim
|
||||
else:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
|
||||
# assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
|
||||
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
|
||||
|
||||
# Check image is in [-1, 1]
|
||||
# if image.min() < -1 or image.max() > 1:
|
||||
# raise ValueError("Image should be in [-1, 1] range")
|
||||
|
||||
# Check mask is in [0, 1]
|
||||
if mask.min() < 0 or mask.max() > 1:
|
||||
raise ValueError("Mask should be in [0, 1] range")
|
||||
|
||||
# Binarize mask
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
# Image as float32
|
||||
image = image.to(dtype=torch.float32)
|
||||
elif isinstance(mask, torch.Tensor):
|
||||
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
|
||||
else:
|
||||
# preprocess image
|
||||
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
||||
# resize all images w.r.t passed height an width
|
||||
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
mask = mask_pil_to_torch(mask, height, width)
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
if image.shape[1] == 4:
|
||||
# images are in latent space and thus can't
|
||||
# be masked set masked_image to None
|
||||
# we assume that the checkpoint is not an inpainting
|
||||
# checkpoint. TOD(Yiyi) - need to clean this up later
|
||||
masked_image = None
|
||||
else:
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
# n.b. ensure backwards compatibility as old function does not return image
|
||||
if return_image:
|
||||
return mask, masked_image, image
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
|
||||
@@ -114,7 +114,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Foward process in flow-matching
|
||||
Forward process in flow-matching
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
|
||||
@@ -36,7 +36,6 @@ from diffusers import (
|
||||
StableDiffusionInpaintPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
@@ -1105,530 +1104,3 @@ class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
|
||||
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
|
||||
def test_pil_inputs(self):
|
||||
height, width = 32, 32
|
||||
im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
|
||||
im = Image.fromarray(im)
|
||||
mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
|
||||
mask = Image.fromarray((mask * 255).astype(np.uint8))
|
||||
|
||||
t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True)
|
||||
|
||||
self.assertTrue(isinstance(t_mask, torch.Tensor))
|
||||
self.assertTrue(isinstance(t_masked, torch.Tensor))
|
||||
self.assertTrue(isinstance(t_image, torch.Tensor))
|
||||
|
||||
self.assertEqual(t_mask.ndim, 4)
|
||||
self.assertEqual(t_masked.ndim, 4)
|
||||
self.assertEqual(t_image.ndim, 4)
|
||||
|
||||
self.assertEqual(t_mask.shape, (1, 1, height, width))
|
||||
self.assertEqual(t_masked.shape, (1, 3, height, width))
|
||||
self.assertEqual(t_image.shape, (1, 3, height, width))
|
||||
|
||||
self.assertTrue(t_mask.dtype == torch.float32)
|
||||
self.assertTrue(t_masked.dtype == torch.float32)
|
||||
self.assertTrue(t_image.dtype == torch.float32)
|
||||
|
||||
self.assertTrue(t_mask.min() >= 0.0)
|
||||
self.assertTrue(t_mask.max() <= 1.0)
|
||||
self.assertTrue(t_masked.min() >= -1.0)
|
||||
self.assertTrue(t_masked.min() <= 1.0)
|
||||
self.assertTrue(t_image.min() >= -1.0)
|
||||
self.assertTrue(t_image.min() >= -1.0)
|
||||
|
||||
self.assertTrue(t_mask.sum() > 0.0)
|
||||
|
||||
def test_np_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
|
||||
im_pil = Image.fromarray(im_np)
|
||||
mask_np = (
|
||||
np.random.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
|
||||
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(
|
||||
im_pil, mask_pil, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_np == t_mask_pil).all())
|
||||
self.assertTrue((t_masked_np == t_masked_pil).all())
|
||||
self.assertTrue((t_image_np == t_image_pil).all())
|
||||
|
||||
def test_torch_3D_2D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy().transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_3D_3D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy().transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()[0]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_4D_2D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_4D_3D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()[0]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_4D_4D_inputs(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
1,
|
||||
1,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
|
||||
mask_np = mask_tensor.numpy()[0][0]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
|
||||
im_np, mask_np, height, width, return_image=True
|
||||
)
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_batch_4D_3D(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
2,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
2,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
|
||||
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
|
||||
mask_nps = [mask.numpy() for mask in mask_tensor]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
|
||||
t_mask_np = torch.cat([n[0] for n in nps])
|
||||
t_masked_np = torch.cat([n[1] for n in nps])
|
||||
t_image_np = torch.cat([n[2] for n in nps])
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_torch_batch_4D_4D(self):
|
||||
height, width = 32, 32
|
||||
|
||||
im_tensor = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
2,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
mask_tensor = (
|
||||
torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
2,
|
||||
1,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
> 127.5
|
||||
)
|
||||
|
||||
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
|
||||
mask_nps = [mask.numpy()[0] for mask in mask_tensor]
|
||||
|
||||
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
|
||||
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
|
||||
)
|
||||
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
|
||||
t_mask_np = torch.cat([n[0] for n in nps])
|
||||
t_masked_np = torch.cat([n[1] for n in nps])
|
||||
t_image_np = torch.cat([n[2] for n in nps])
|
||||
|
||||
self.assertTrue((t_mask_tensor == t_mask_np).all())
|
||||
self.assertTrue((t_masked_tensor == t_masked_np).all())
|
||||
self.assertTrue((t_image_tensor == t_image_np).all())
|
||||
|
||||
def test_shape_mismatch(self):
|
||||
height, width = 32, 32
|
||||
|
||||
# test height and width
|
||||
with self.assertRaises(AssertionError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.randn(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.randn(64, 64),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test batch dim
|
||||
with self.assertRaises(AssertionError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.randn(
|
||||
2,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.randn(4, 64, 64),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test batch dim
|
||||
with self.assertRaises(AssertionError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.randn(
|
||||
2,
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.randn(4, 1, 64, 64),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
def test_type_mismatch(self):
|
||||
height, width = 32, 32
|
||||
|
||||
# test tensors-only
|
||||
with self.assertRaises(TypeError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
).numpy(),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test tensors-only
|
||||
with self.assertRaises(TypeError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
).numpy(),
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
def test_channels_first(self):
|
||||
height, width = 32, 32
|
||||
|
||||
# test channels first for 3D tensors
|
||||
with self.assertRaises(AssertionError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(height, width, 3),
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
def test_tensor_range(self):
|
||||
height, width = 32, 32
|
||||
|
||||
# test im <= 1
|
||||
with self.assertRaises(ValueError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.ones(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
)
|
||||
* 2,
|
||||
torch.rand(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test im >= -1
|
||||
with self.assertRaises(ValueError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.ones(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
)
|
||||
* (-2),
|
||||
torch.rand(
|
||||
height,
|
||||
width,
|
||||
),
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test mask <= 1
|
||||
with self.assertRaises(ValueError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.ones(
|
||||
height,
|
||||
width,
|
||||
)
|
||||
* 2,
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
# test mask >= 0
|
||||
with self.assertRaises(ValueError):
|
||||
prepare_mask_and_masked_image(
|
||||
torch.rand(
|
||||
3,
|
||||
height,
|
||||
width,
|
||||
),
|
||||
torch.ones(
|
||||
height,
|
||||
width,
|
||||
)
|
||||
* -1,
|
||||
height,
|
||||
width,
|
||||
return_image=True,
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ python utils/update_metadata.py
|
||||
Script modified from:
|
||||
https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
Reference in New Issue
Block a user