Compare commits

...

29 Commits

Author SHA1 Message Date
Sayak Paul 7394b99047 Merge branch 'main' into fix-casting-training-params 2024-06-18 14:45:33 +01:00
Sayak Paul 4edde134f6 [SD3 training] refactor the density and weighting utilities. (#8591)
refactor the density and weighting utilities.
2024-06-18 14:44:38 +01:00
Sayak Paul 4716a413bf Merge branch 'main' into fix-casting-training-params 2024-06-18 14:18:05 +01:00
Bagheera 074a7cc3c5 SD3: update default training timestep / loss weighting distribution to logit_normal (#8592)
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-18 14:15:19 +01:00
Álvaro Somoza 6bfd13f07a [SD3 Training] T5 token limit (#8564)
* initial commit

* default back to 77

* better text

* text correction

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-17 16:32:56 -04:00
Sayak Paul c843fb25ec Merge branch 'main' into fix-casting-training-params 2024-06-17 20:43:54 +01:00
AmosDinh eeb70033a6 Syntax error in readme example "pipe" -> "pipeline" (#8601)
Update controlnet.md

Syntax error pipe -> pipeline
2024-06-17 11:02:07 -07:00
Dhruv Nair c4a4750cb3 Temporarily pin Numpy in the CI (#8603)
temp pin numpy
2024-06-17 19:32:38 +05:30
YiYi Xu a6375d4101 Image processor latent (#8513)
* fix

* up

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-06-16 22:34:55 -10:00
spacepxl 8e1b7a084a Fix the deletion of SD3 text encoders for Dreambooth/LoRA training if the text encoders are not being trained (#8536)
* Update train_dreambooth_sd3.py to fix TE garbage collection

* Update train_dreambooth_lora_sd3.py to fix TE garbage collection

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-16 20:52:33 +01:00
Rafie Walker 6946facf69 Implement SD3 loss weighting (#8528)
* Add lognorm and cosmap weighting

* Implement mode sampling

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* keep timestamp sampling fully on cpu

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-16 20:15:50 +01:00
Sayak Paul 130dd936bb pin accelerate to 0.31.0 (#8563)
* pin accelerate to 0.31.0

* update dep table

* empty
2024-06-16 08:37:00 -10:00
Jonathan Rahn a899e42fc7 add sentencepiece to requirements.txt for SD3 dreambooth (#8538)
* add `sentencepiece` requirement for SD3

add `sentencepiece` requirement

* Empty-Commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-14 22:48:36 +01:00
Sayak Paul f96e4a16ad pin transformers to the latest (#8522)
thanks!
2024-06-13 07:39:24 -10:00
Tolga Cangöz 9c6e9684a2 Refactor StableDiffusion3Img2ImgPipeline to remove redundant code (#8533) 2024-06-13 07:36:46 -10:00
Sayak Paul 2e4841ef1e post release 0.29.0 (#8492)
post release
2024-06-13 06:14:20 -10:00
Haofan Wang 8bea943714 Update requirements_sd3.txt (#8521) 2024-06-13 17:02:17 +01:00
YiYi Xu 614d0c64e9 remove the deprecated prepare_mask_and_masked_image function (#8512)
remove prepare mask fn

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-06-13 14:59:21 +01:00
Dhruv Nair b1a2c0d577 Expand Single File support in SD3 Pipeline (#8517)
* update

* update
2024-06-13 18:29:19 +05:30
Lucain 06ee907b73 Fix PATH_IN_REPO on new release in mirror_community_pipeline.yaml (#8519)
Fix PATH_IN_REPO in mirror workflow
2024-06-13 10:25:24 +02:00
ちくわぶ 896fb6d8d7 Fix duplicate variable assignments in SD3's JointAttnProcessor (#8516)
* Fix duplicate variable assignments.

* Fix duplicate variable assignments.
2024-06-12 21:52:35 -10:00
Beinsezii 7f51f286a5 Add Hunyuan AutoPipe mapping (#8505) 2024-06-12 16:11:55 -10:00
kkj15dk 829f6defa4 Fix spelling in scheduling_flow_match_euler_discrete.py (#8497)
Update scheduling_flow_match_euler_discrete.py

Spelling:
Foward -> Forward

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-06-12 12:37:47 -10:00
Beinsezii 24bdf4b215 Add SD3 AutoPipeline mappings (#8489) 2024-06-12 12:31:36 -10:00
Radamés Ajna 95e0c3757d Fix small typo (#8498) 2024-06-12 15:30:58 -07:00
Sayak Paul 6cf0be5d3d fix warning log for Transformer SD3 (#8496)
fix warning log
2024-06-12 12:25:18 -10:00
Sayak Paul 38d768a876 Merge branch 'main' into fix-casting-training-params 2024-06-11 13:09:41 +01:00
Sayak Paul 2c35ea66d9 Merge branch 'main' into fix-casting-training-params 2024-06-10 13:57:27 +01:00
sayakpaul 34fbd5526a fix the position of param casting when loading them 2024-06-10 13:53:59 +01:00
70 changed files with 279 additions and 1027 deletions
@@ -54,7 +54,7 @@ jobs:
else
# e.g. refs/tags/v0.28.1 -> v0.28.1
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=${${{ github.ref }}#refs/tags/}" >> $GITHUB_ENV
echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
fi
- name: Print env vars
run: |
+1 -1
View File
@@ -42,7 +42,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \
Jinja2 \
librosa \
numpy \
numpy==1.26.4 \
scipy \
tensorboard \
transformers \
+1 -1
View File
@@ -40,7 +40,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \
Jinja2 \
librosa \
numpy \
numpy==1.26.4 \
scipy \
tensorboard \
transformers
+2 -2
View File
@@ -41,8 +41,8 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
hf-doc-builder \
huggingface-hub \
Jinja2 \
librosa \
numpy \
librosa \
numpy==1.26.4 \
scipy \
tensorboard \
transformers
+1 -1
View File
@@ -40,7 +40,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \
Jinja2 \
librosa \
numpy \
numpy==1.26.4 \
scipy \
tensorboard \
transformers
+1 -1
View File
@@ -40,7 +40,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \
Jinja2 \
librosa \
numpy \
numpy==1.26.4 \
scipy \
tensorboard \
transformers
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \
Jinja2 \
librosa \
numpy \
numpy==1.26.4 \
scipy \
tensorboard \
transformers
+1 -1
View File
@@ -40,7 +40,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \
Jinja2 \
librosa \
numpy \
numpy==1.26.4 \
scipy \
tensorboard \
transformers matplotlib
+1 -1
View File
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \
Jinja2 \
librosa \
numpy \
numpy==1.26.4 \
scipy \
tensorboard \
transformers \
@@ -39,7 +39,7 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
huggingface-hub \
Jinja2 \
librosa \
numpy \
numpy==1.26.4 \
scipy \
tensorboard \
transformers \
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`StableDiffusionXLInstructPix2PixPipeline`]
- [`StableDiffusionXLControlNetPipeline`]
- [`StableDiffusionXLKDiffusionPipeline`]
- [`StableDiffusion3Pipeline`]
- [`LatentConsistencyModelPipeline`]
- [`LatentConsistencyModelImg2ImgPipeline`]
- [`StableDiffusionControlNetXSPipeline`]
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`StableCascadeUNet`]
- [`AutoencoderKL`]
- [`ControlNetModel`]
- [`SD3Transformer2DModel`]
## FromSingleFileMixin
@@ -21,9 +21,9 @@ The abstract from the paper is:
## Usage Example
_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
Use the command below to log in:
Use the command below to log in:
```bash
huggingface-cli login
@@ -186,7 +186,7 @@ pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgrap
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
# Warm Up
prompt = "a photo of a cat holding a sign that says hello world",
prompt = "a photo of a cat holding a sign that says hello world"
for _ in range(3):
_ = pipe(prompt=prompt, generator=torch.manual_seed(1))
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability
## Loading the single checkpoint for the `StableDiffusion3Pipeline`
```python
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel
### Loading the single file checkpoint without T5
text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
```python
import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_single_file(
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
torch_dtype=torch.float16,
text_encoder_3=None
)
pipe.enable_model_cpu_offload()
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file.png')
```
<Tip>
`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
</Tip>
### Loading the single file checkpoint without T5
```python
import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_single_file(
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file-t5-fp8.png')
```
## StableDiffusion3Pipeline
+2 -2
View File
@@ -349,7 +349,7 @@ control_image = load_image("./conditioning_image_1.png")
prompt = "pale golden rod circle with old lace background"
generator = torch.manual_seed(0)
image = pipe(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
image = pipeline(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
image.save("./output.png")
```
@@ -363,4 +363,4 @@ The SDXL training script is discussed in more detail in the [SDXL training](sdxl
Congratulations on training your own ControlNet! To learn more about how to use your new model, the following guides may be helpful:
- Learn how to [use a ControlNet](../using-diffusers/controlnet) for inference on a variety of tasks.
- Learn how to [use a ControlNet](../using-diffusers/controlnet) for inference on a variety of tasks.
@@ -71,7 +71,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -78,7 +78,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
class MarigoldDepthOutput(BaseOutput):
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
+3
View File
@@ -106,6 +106,9 @@ To better track our training experiments, we're using the following flags in the
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
+2 -1
View File
@@ -4,4 +4,5 @@ transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft== 0.11.1
peft==0.11.1
sentencepiece
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+1 -1
View File
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -53,7 +53,11 @@ from diffusers import (
StableDiffusion3Pipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.training_utils import (
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
@@ -67,7 +71,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -298,6 +302,12 @@ def parse_args(input_args=None):
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--max_sequence_length",
type=int,
default=77,
help="Maximum sequence length to use with with the T5 text encoder",
)
parser.add_argument(
"--validation_prompt",
type=str,
@@ -467,11 +477,20 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"]
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument("--logit_mean", type=float, default=0.0)
parser.add_argument("--logit_std", type=float, default=1.0)
parser.add_argument("--mode_scale", type=float, default=1.29)
parser.add_argument(
"--optimizer",
type=str,
@@ -830,6 +849,7 @@ def tokenize_prompt(tokenizer, prompt):
def _encode_prompt_with_t5(
text_encoder,
tokenizer,
max_sequence_length,
prompt=None,
num_images_per_prompt=1,
device=None,
@@ -840,7 +860,7 @@ def _encode_prompt_with_t5(
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
@@ -897,6 +917,7 @@ def encode_prompt(
text_encoders,
tokenizers,
prompt: str,
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
):
@@ -924,6 +945,7 @@ def encode_prompt(
t5_prompt_embed = _encode_prompt_with_t5(
text_encoders[-1],
tokenizers[-1],
max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[-1].device,
@@ -1297,7 +1319,9 @@ def main(args):
def compute_text_embeddings(prompt, text_encoders, tokenizers):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders, tokenizers, prompt, args.max_sequence_length
)
prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds
@@ -1316,6 +1340,9 @@ def main(args):
# Clear the memory here
if not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del tokenizer_one, tokenizer_two, tokenizer_three
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -1462,7 +1489,15 @@ def main(args):
bsz = model_input.shape[0]
# Sample a random timestep for each image
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching.
@@ -1482,20 +1517,11 @@ def main(args):
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
if args.weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
weighting = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
# See sec 3.1 in the SD3 paper (20).
u = torch.rand(size=(bsz,), device=accelerator.device)
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# simplified flow matching aka 0-rectified flow matching loss
# target = model_input - noise
# flow matching loss
target = model_input
if args.with_prior_preservation:
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -1289,8 +1289,8 @@ def main(args):
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
+44 -22
View File
@@ -51,6 +51,7 @@ from diffusers import (
StableDiffusion3Pipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from diffusers.utils import (
check_min_version,
is_wandb_available,
@@ -63,7 +64,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -297,6 +298,12 @@ def parse_args(input_args=None):
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--max_sequence_length",
type=int,
default=77,
help="Maximum sequence length to use with with the T5 text encoder",
)
parser.add_argument(
"--validation_prompt",
type=str,
@@ -465,11 +472,20 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"]
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument("--logit_mean", type=float, default=0.0)
parser.add_argument("--logit_std", type=float, default=1.0)
parser.add_argument("--mode_scale", type=float, default=1.29)
parser.add_argument(
"--optimizer",
type=str,
@@ -828,6 +844,7 @@ def tokenize_prompt(tokenizer, prompt):
def _encode_prompt_with_t5(
text_encoder,
tokenizer,
max_sequence_length,
prompt=None,
num_images_per_prompt=1,
device=None,
@@ -838,7 +855,7 @@ def _encode_prompt_with_t5(
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
@@ -895,6 +912,7 @@ def encode_prompt(
text_encoders,
tokenizers,
prompt: str,
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
):
@@ -922,6 +940,7 @@ def encode_prompt(
t5_prompt_embed = _encode_prompt_with_t5(
text_encoders[-1],
tokenizers[-1],
max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[-1].device,
@@ -1324,7 +1343,9 @@ def main(args):
def compute_text_embeddings(prompt, text_encoders, tokenizers):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders, tokenizers, prompt, args.max_sequence_length
)
prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds
@@ -1347,6 +1368,9 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del tokenizer_one, tokenizer_two, tokenizer_three
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -1526,7 +1550,15 @@ def main(args):
bsz = model_input.shape[0]
# Sample a random timestep for each image
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching.
@@ -1560,21 +1592,11 @@ def main(args):
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
if args.weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
weighting = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
# See sec 3.1 in the SD3 paper (20).
u = torch.rand(size=(bsz,), device=accelerator.device)
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
# simplified flow matching aka 0-rectified flow matching loss
# target = model_input - noise
# flow matching loss
target = model_input
if args.with_prior_preservation:
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -1363,8 +1363,8 @@ def main(args):
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+3 -3
View File
@@ -95,7 +95,7 @@ from setuptools import Command, find_packages, setup
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
"Pillow", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.29.3",
"accelerate>=0.31.0",
"compel==0.1.8",
"datasets",
"filelock",
@@ -132,7 +132,7 @@ _deps = [
"tensorboard",
"torch>=1.4",
"torchvision",
"transformers>=4.25.1",
"transformers>=4.41.2",
"urllib3<=2.0.0",
"black",
]
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.29.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.30.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+1 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.29.0.dev0"
__version__ = "0.30.0.dev0"
from typing import TYPE_CHECKING
+2 -2
View File
@@ -3,7 +3,7 @@
# 2. run `make deps_table_update`
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.29.3",
"accelerate": "accelerate>=0.31.0",
"compel": "compel==0.1.8",
"datasets": "datasets",
"filelock": "filelock",
@@ -40,7 +40,7 @@ deps = {
"tensorboard": "tensorboard",
"torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.25.1",
"transformers": "transformers>=4.41.2",
"urllib3": "urllib3<=2.0.0",
"black": "black",
}
+1 -2
View File
@@ -569,7 +569,7 @@ class VaeImageProcessor(ConfigMixin):
channel = image.shape[1]
# don't need any preprocess if the image is latents
if channel == 4:
if channel == self.vae_latent_channels:
return image
height, width = self.get_default_height_width(image, height, width)
@@ -585,7 +585,6 @@ class VaeImageProcessor(ConfigMixin):
FutureWarning,
)
do_normalize = False
if do_normalize:
image = self.normalize(image)
+12
View File
@@ -28,9 +28,11 @@ from .single_file_utils import (
_legacy_load_safety_checker,
_legacy_load_scheduler,
create_diffusers_clip_model_from_ldm,
create_diffusers_t5_model_from_checkpoint,
fetch_diffusers_config,
fetch_original_config,
is_clip_model_in_single_file,
is_t5_in_single_file,
load_single_file_checkpoint,
)
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
is_legacy_loading=is_legacy_loading,
)
elif is_transformers_model and is_t5_in_single_file(checkpoint):
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
)
elif is_tokenizer and is_legacy_loading:
loaded_sub_model = _legacy_load_clip_tokenizer(
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
+23 -9
View File
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE = [
"cond_stage_model.transformer.",
"conditioner.embedders.0.transformer.",
"text_encoders.clip_l.transformer.",
]
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
def is_open_clip_sd3_model(checkpoint):
is_open_clip_sdxl_refiner_model(checkpoint)
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
return True
return False
def is_open_clip_sdxl_refiner_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
return True
return False
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint
def convert_ldm_clip_checkpoint(checkpoint):
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
keys = list(checkpoint.keys())
text_model_dict = {}
remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
remove_prefixes = []
remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
if remove_prefix:
remove_prefixes.append(remove_prefix)
for key in keys:
for prefix in remove_prefixes:
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
elif (
is_clip_sd3_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
elif is_open_clip_model(checkpoint):
prefix = "cond_stage_model.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
prefix = "conditioner.embedders.0.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
elif is_open_clip_sd3_model(checkpoint):
prefix = "text_encoders.clip_g.transformer."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
elif (
is_open_clip_sd3_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
else:
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
keys = list(checkpoint.keys())
text_model_dict = {}
remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
remove_prefixes = ["text_encoders.t5xxl.transformer."]
for key in keys:
for prefix in remove_prefixes:
+1 -4
View File
@@ -1132,9 +1132,7 @@ class JointAttnProcessor2_0:
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -1406,7 +1404,6 @@ class XFormersAttnProcessor:
class AttnProcessorNPU:
r"""
Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
@@ -282,9 +282,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.shape[-2:]
+8
View File
@@ -27,6 +27,7 @@ from .controlnet import (
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
@@ -52,6 +53,10 @@ from .stable_diffusion import (
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
)
from .stable_diffusion_3 import (
StableDiffusion3Img2ImgPipeline,
StableDiffusion3Pipeline,
)
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
@@ -64,7 +69,9 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion", StableDiffusionPipeline),
("stable-diffusion-xl", StableDiffusionXLPipeline),
("stable-diffusion-3", StableDiffusion3Pipeline),
("if", IFPipeline),
("hunyuan", HunyuanDiTPipeline),
("kandinsky", KandinskyCombinedPipeline),
("kandinsky22", KandinskyV22CombinedPipeline),
("kandinsky3", Kandinsky3Pipeline),
@@ -82,6 +89,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion", StableDiffusionImg2ImgPipeline),
("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
("stable-diffusion-3", StableDiffusion3Img2ImgPipeline),
("if", IFImg2ImgPipeline),
("kandinsky", KandinskyImg2ImgCombinedPipeline),
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
@@ -118,129 +118,6 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask is None:
raise ValueError("`mask_image` input cannot be undefined.")
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
# Batch single image
if image.ndim == 3:
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t passed height an width
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
# n.b. ensure backwards compatibility as old function does not return image
if return_image:
return mask, masked_image, image
return mask, masked_image
class StableDiffusionControlNetInpaintPipeline(
DiffusionPipeline,
StableDiffusionMixin,
@@ -15,7 +15,6 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from packaging import version
@@ -38,128 +37,6 @@ from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask is None:
raise ValueError("`mask_image` input cannot be undefined.")
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
# Batch single image
if image.ndim == 3:
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t passed height an width
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
# n.b. ensure backwards compatibility as old function does not return image
if return_image:
return mask, masked_image, image
return mask, masked_image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -591,8 +591,6 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
)
image = image.to(device=device, dtype=dtype)
if image.shape[1] == self.vae.config.latent_channels:
init_latents = image
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == self.vae.config.latent_channels:
@@ -132,124 +132,6 @@ def mask_pil_to_torch(mask, height, width):
return mask
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(ot the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
# checkpoint. TOD(Yiyi) - need to clean this up later
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
if mask is None:
raise ValueError("`mask_image` input cannot be undefined.")
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
mask = mask_pil_to_torch(mask, height, width)
if image.ndim == 3:
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
# assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
# if image.min() < -1 or image.max() > 1:
# raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t passed height an width
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
mask = mask_pil_to_torch(mask, height, width)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
if image.shape[1] == 4:
# images are in latent space and thus can't
# be masked set masked_image to None
# we assume that the checkpoint is not an inpainting
# checkpoint. TOD(Yiyi) - need to clean this up later
masked_image = None
else:
masked_image = image * (mask < 0.5)
# n.b. ensure backwards compatibility as old function does not return image
if return_image:
return mask, masked_image, image
return mask, masked_image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -114,7 +114,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Foward process in flow-matching
Forward process in flow-matching
Args:
sample (`torch.FloatTensor`):
+39
View File
@@ -1,5 +1,6 @@
import contextlib
import copy
import math
import random
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -220,6 +221,44 @@ def _set_state_dict_into_text_encoder(
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
@@ -36,7 +36,6 @@ from diffusers import (
StableDiffusionInpaintPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
@@ -1105,530 +1104,3 @@ class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
def test_pil_inputs(self):
height, width = 32, 32
im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
im = Image.fromarray(im)
mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
mask = Image.fromarray((mask * 255).astype(np.uint8))
t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True)
self.assertTrue(isinstance(t_mask, torch.Tensor))
self.assertTrue(isinstance(t_masked, torch.Tensor))
self.assertTrue(isinstance(t_image, torch.Tensor))
self.assertEqual(t_mask.ndim, 4)
self.assertEqual(t_masked.ndim, 4)
self.assertEqual(t_image.ndim, 4)
self.assertEqual(t_mask.shape, (1, 1, height, width))
self.assertEqual(t_masked.shape, (1, 3, height, width))
self.assertEqual(t_image.shape, (1, 3, height, width))
self.assertTrue(t_mask.dtype == torch.float32)
self.assertTrue(t_masked.dtype == torch.float32)
self.assertTrue(t_image.dtype == torch.float32)
self.assertTrue(t_mask.min() >= 0.0)
self.assertTrue(t_mask.max() <= 1.0)
self.assertTrue(t_masked.min() >= -1.0)
self.assertTrue(t_masked.min() <= 1.0)
self.assertTrue(t_image.min() >= -1.0)
self.assertTrue(t_image.min() >= -1.0)
self.assertTrue(t_mask.sum() > 0.0)
def test_np_inputs(self):
height, width = 32, 32
im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
im_pil = Image.fromarray(im_np)
mask_np = (
np.random.randint(
0,
255,
(
height,
width,
),
dtype=np.uint8,
)
> 127.5
)
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(
im_pil, mask_pil, height, width, return_image=True
)
self.assertTrue((t_mask_np == t_mask_pil).all())
self.assertTrue((t_masked_np == t_masked_pil).all())
self.assertTrue((t_image_np == t_image_pil).all())
def test_torch_3D_2D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_3D_3D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_4D_2D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
1,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_4D_3D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
1,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_4D_4D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
1,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
1,
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0][0]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
im_np, mask_np, height, width, return_image=True
)
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_batch_4D_3D(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
2,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
2,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy() for mask in mask_tensor]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])
t_image_np = torch.cat([n[2] for n in nps])
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_torch_batch_4D_4D(self):
height, width = 32, 32
im_tensor = torch.randint(
0,
255,
(
2,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
2,
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy()[0] for mask in mask_tensor]
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
)
nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])
t_image_np = torch.cat([n[2] for n in nps])
self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())
self.assertTrue((t_image_tensor == t_image_np).all())
def test_shape_mismatch(self):
height, width = 32, 32
# test height and width
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(
torch.randn(
3,
height,
width,
),
torch.randn(64, 64),
height,
width,
return_image=True,
)
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(
torch.randn(
2,
3,
height,
width,
),
torch.randn(4, 64, 64),
height,
width,
return_image=True,
)
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(
torch.randn(
2,
3,
height,
width,
),
torch.randn(4, 1, 64, 64),
height,
width,
return_image=True,
)
def test_type_mismatch(self):
height, width = 32, 32
# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
),
torch.rand(
3,
height,
width,
).numpy(),
height,
width,
return_image=True,
)
# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
).numpy(),
torch.rand(
3,
height,
width,
),
height,
width,
return_image=True,
)
def test_channels_first(self):
height, width = 32, 32
# test channels first for 3D tensors
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(
torch.rand(height, width, 3),
torch.rand(
3,
height,
width,
),
height,
width,
return_image=True,
)
def test_tensor_range(self):
height, width = 32, 32
# test im <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(
torch.ones(
3,
height,
width,
)
* 2,
torch.rand(
height,
width,
),
height,
width,
return_image=True,
)
# test im >= -1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(
torch.ones(
3,
height,
width,
)
* (-2),
torch.rand(
height,
width,
),
height,
width,
return_image=True,
)
# test mask <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
),
torch.ones(
height,
width,
)
* 2,
height,
width,
return_image=True,
)
# test mask >= 0
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
),
torch.ones(
height,
width,
)
* -1,
height,
width,
return_image=True,
)
+1
View File
@@ -24,6 +24,7 @@ python utils/update_metadata.py
Script modified from:
https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py
"""
import argparse
import os
import tempfile