|
|
|
@@ -18,7 +18,7 @@ from diffusers import (
|
|
|
|
|
StableDiffusionPipeline,
|
|
|
|
|
UNet2DConditionModel,
|
|
|
|
|
)
|
|
|
|
|
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
|
|
|
|
|
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
|
|
|
|
|
from diffusers.utils.testing_utils import torch_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -210,6 +210,135 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
|
|
|
|
self.assertFalse(is_safetensors_compatible(filenames))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VariantCompatibleSiblingsTest(unittest.TestCase):
|
|
|
|
|
def test_only_non_variants_downloaded(self):
|
|
|
|
|
variant = "fp16"
|
|
|
|
|
filenames = [
|
|
|
|
|
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"vae/diffusion_pytorch_model.safetensors",
|
|
|
|
|
f"text_encoder/model.{variant}.safetensors",
|
|
|
|
|
"text_encoder/model.safetensors",
|
|
|
|
|
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"unet/diffusion_pytorch_model.safetensors",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
|
|
|
|
assert all(variant not in f for f in model_filenames)
|
|
|
|
|
|
|
|
|
|
def test_only_variants_downloaded(self):
|
|
|
|
|
variant = "fp16"
|
|
|
|
|
filenames = [
|
|
|
|
|
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"vae/diffusion_pytorch_model.safetensors",
|
|
|
|
|
f"text_encoder/model.{variant}.safetensors",
|
|
|
|
|
"text_encoder/model.safetensors",
|
|
|
|
|
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"unet/diffusion_pytorch_model.safetensors",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
|
|
|
|
assert all(variant in f for f in model_filenames)
|
|
|
|
|
|
|
|
|
|
def test_mixed_variants_downloaded(self):
|
|
|
|
|
variant = "fp16"
|
|
|
|
|
non_variant_file = "text_encoder/model.safetensors"
|
|
|
|
|
filenames = [
|
|
|
|
|
f"vae/diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"vae/diffusion_pytorch_model.safetensors",
|
|
|
|
|
"text_encoder/model.safetensors",
|
|
|
|
|
f"unet/diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"unet/diffusion_pytorch_model.safetensors",
|
|
|
|
|
]
|
|
|
|
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
|
|
|
|
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
|
|
|
|
|
|
|
|
|
def test_non_variants_in_main_dir_downloaded(self):
|
|
|
|
|
variant = "fp16"
|
|
|
|
|
filenames = [
|
|
|
|
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"diffusion_pytorch_model.safetensors",
|
|
|
|
|
"model.safetensors",
|
|
|
|
|
f"model.{variant}.safetensors",
|
|
|
|
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"diffusion_pytorch_model.safetensors",
|
|
|
|
|
]
|
|
|
|
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
|
|
|
|
assert all(variant not in f for f in model_filenames)
|
|
|
|
|
|
|
|
|
|
def test_variants_in_main_dir_downloaded(self):
|
|
|
|
|
variant = "fp16"
|
|
|
|
|
filenames = [
|
|
|
|
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"diffusion_pytorch_model.safetensors",
|
|
|
|
|
"model.safetensors",
|
|
|
|
|
f"model.{variant}.safetensors",
|
|
|
|
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"diffusion_pytorch_model.safetensors",
|
|
|
|
|
]
|
|
|
|
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
|
|
|
|
assert all(variant in f for f in model_filenames)
|
|
|
|
|
|
|
|
|
|
def test_mixed_variants_in_main_dir_downloaded(self):
|
|
|
|
|
variant = "fp16"
|
|
|
|
|
non_variant_file = "model.safetensors"
|
|
|
|
|
filenames = [
|
|
|
|
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"diffusion_pytorch_model.safetensors",
|
|
|
|
|
"model.safetensors",
|
|
|
|
|
f"diffusion_pytorch_model.{variant}.safetensors",
|
|
|
|
|
"diffusion_pytorch_model.safetensors",
|
|
|
|
|
]
|
|
|
|
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
|
|
|
|
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
|
|
|
|
|
|
|
|
|
|
def test_sharded_non_variants_downloaded(self):
|
|
|
|
|
variant = "fp16"
|
|
|
|
|
filenames = [
|
|
|
|
|
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
|
|
|
|
"unet/diffusion_pytorch_model.safetensors.index.json",
|
|
|
|
|
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
|
|
|
|
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
|
|
|
|
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
|
|
|
|
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
|
|
|
|
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
|
|
|
|
]
|
|
|
|
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
|
|
|
|
|
assert all(variant not in f for f in model_filenames)
|
|
|
|
|
|
|
|
|
|
def test_sharded_variants_downloaded(self):
|
|
|
|
|
variant = "fp16"
|
|
|
|
|
filenames = [
|
|
|
|
|
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
|
|
|
|
"unet/diffusion_pytorch_model.safetensors.index.json",
|
|
|
|
|
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
|
|
|
|
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
|
|
|
|
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
|
|
|
|
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
|
|
|
|
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
|
|
|
|
]
|
|
|
|
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
|
|
|
|
assert all(variant in f for f in model_filenames)
|
|
|
|
|
|
|
|
|
|
def test_sharded_mixed_variants_downloaded(self):
|
|
|
|
|
variant = "fp16"
|
|
|
|
|
allowed_non_variant = "unet"
|
|
|
|
|
filenames = [
|
|
|
|
|
f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json",
|
|
|
|
|
"vae/diffusion_pytorch_model.safetensors.index.json",
|
|
|
|
|
"unet/diffusion_pytorch_model.safetensors.index.json",
|
|
|
|
|
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
|
|
|
|
|
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
|
|
|
|
|
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
|
|
|
|
|
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
|
|
|
|
|
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
|
|
|
|
|
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
|
|
|
|
|
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
|
|
|
|
|
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
|
|
|
|
|
]
|
|
|
|
|
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
|
|
|
|
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProgressBarTests(unittest.TestCase):
|
|
|
|
|
def get_dummy_components_image_generation(self):
|
|
|
|
|
cross_attention_dim = 8
|
|
|
|
|