is_safetensors_compatible refactor (#2499)
* is_safetensors_compatible refactor * files list comma
This commit is contained in:
@@ -129,21 +129,49 @@ class AudioPipelineOutput(BaseOutput):
|
|||||||
|
|
||||||
|
|
||||||
def is_safetensors_compatible(filenames, variant=None) -> bool:
|
def is_safetensors_compatible(filenames, variant=None) -> bool:
|
||||||
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
|
"""
|
||||||
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
|
Checking for safetensors compatibility:
|
||||||
|
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
|
||||||
|
files to know which safetensors files are needed.
|
||||||
|
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
|
||||||
|
|
||||||
for pt_filename in pt_filenames:
|
Converting default pytorch serialized filenames to safetensors serialized filenames:
|
||||||
_variant = f".{variant}" if (variant is not None and variant in pt_filename) else ""
|
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
|
||||||
prefix, raw = os.path.split(pt_filename)
|
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
|
||||||
if raw == f"pytorch_model{_variant}.bin":
|
extension is replaced with ".safetensors"
|
||||||
# transformers specific
|
"""
|
||||||
sf_filename = os.path.join(prefix, f"model{_variant}.safetensors")
|
pt_filenames = []
|
||||||
|
|
||||||
|
sf_filenames = set()
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
_, extension = os.path.splitext(filename)
|
||||||
|
|
||||||
|
if extension == ".bin":
|
||||||
|
pt_filenames.append(filename)
|
||||||
|
elif extension == ".safetensors":
|
||||||
|
sf_filenames.add(filename)
|
||||||
|
|
||||||
|
for filename in pt_filenames:
|
||||||
|
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
|
||||||
|
path, filename = os.path.split(filename)
|
||||||
|
filename, extension = os.path.splitext(filename)
|
||||||
|
|
||||||
|
if filename == "pytorch_model":
|
||||||
|
filename = "model"
|
||||||
|
elif filename == f"pytorch_model.{variant}":
|
||||||
|
filename = f"model.{variant}"
|
||||||
else:
|
else:
|
||||||
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
|
filename = filename
|
||||||
if is_safetensors_compatible and sf_filename not in filenames:
|
|
||||||
logger.warning(f"{sf_filename} not found")
|
expected_sf_filename = os.path.join(path, filename)
|
||||||
is_safetensors_compatible = False
|
expected_sf_filename = f"{expected_sf_filename}.safetensors"
|
||||||
return is_safetensors_compatible
|
|
||||||
|
if expected_sf_filename not in sf_filenames:
|
||||||
|
logger.warning(f"{expected_sf_filename} not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
|
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
|
||||||
|
|||||||
@@ -0,0 +1,134 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
|
||||||
|
|
||||||
|
|
||||||
|
class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||||
|
def test_all_is_compatible(self):
|
||||||
|
filenames = [
|
||||||
|
"safety_checker/pytorch_model.bin",
|
||||||
|
"safety_checker/model.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model.bin",
|
||||||
|
"vae/diffusion_pytorch_model.safetensors",
|
||||||
|
"text_encoder/pytorch_model.bin",
|
||||||
|
"text_encoder/model.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model.bin",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
self.assertTrue(is_safetensors_compatible(filenames))
|
||||||
|
|
||||||
|
def test_diffusers_model_is_compatible(self):
|
||||||
|
filenames = [
|
||||||
|
"unet/diffusion_pytorch_model.bin",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
self.assertTrue(is_safetensors_compatible(filenames))
|
||||||
|
|
||||||
|
def test_diffusers_model_is_not_compatible(self):
|
||||||
|
filenames = [
|
||||||
|
"safety_checker/pytorch_model.bin",
|
||||||
|
"safety_checker/model.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model.bin",
|
||||||
|
"vae/diffusion_pytorch_model.safetensors",
|
||||||
|
"text_encoder/pytorch_model.bin",
|
||||||
|
"text_encoder/model.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model.bin",
|
||||||
|
# Removed: 'unet/diffusion_pytorch_model.safetensors',
|
||||||
|
]
|
||||||
|
self.assertFalse(is_safetensors_compatible(filenames))
|
||||||
|
|
||||||
|
def test_transformer_model_is_compatible(self):
|
||||||
|
filenames = [
|
||||||
|
"text_encoder/pytorch_model.bin",
|
||||||
|
"text_encoder/model.safetensors",
|
||||||
|
]
|
||||||
|
self.assertTrue(is_safetensors_compatible(filenames))
|
||||||
|
|
||||||
|
def test_transformer_model_is_not_compatible(self):
|
||||||
|
filenames = [
|
||||||
|
"safety_checker/pytorch_model.bin",
|
||||||
|
"safety_checker/model.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model.bin",
|
||||||
|
"vae/diffusion_pytorch_model.safetensors",
|
||||||
|
"text_encoder/pytorch_model.bin",
|
||||||
|
# Removed: 'text_encoder/model.safetensors',
|
||||||
|
"unet/diffusion_pytorch_model.bin",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
self.assertFalse(is_safetensors_compatible(filenames))
|
||||||
|
|
||||||
|
def test_all_is_compatible_variant(self):
|
||||||
|
filenames = [
|
||||||
|
"safety_checker/pytorch_model.fp16.bin",
|
||||||
|
"safety_checker/model.fp16.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model.fp16.bin",
|
||||||
|
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||||
|
"text_encoder/pytorch_model.fp16.bin",
|
||||||
|
"text_encoder/model.fp16.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model.fp16.bin",
|
||||||
|
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||||
|
]
|
||||||
|
variant = "fp16"
|
||||||
|
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||||
|
|
||||||
|
def test_diffusers_model_is_compatible_variant(self):
|
||||||
|
filenames = [
|
||||||
|
"unet/diffusion_pytorch_model.fp16.bin",
|
||||||
|
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||||
|
]
|
||||||
|
variant = "fp16"
|
||||||
|
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||||
|
|
||||||
|
def test_diffusers_model_is_compatible_variant_partial(self):
|
||||||
|
# pass variant but use the non-variant filenames
|
||||||
|
filenames = [
|
||||||
|
"unet/diffusion_pytorch_model.bin",
|
||||||
|
"unet/diffusion_pytorch_model.safetensors",
|
||||||
|
]
|
||||||
|
variant = "fp16"
|
||||||
|
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||||
|
|
||||||
|
def test_diffusers_model_is_not_compatible_variant(self):
|
||||||
|
filenames = [
|
||||||
|
"safety_checker/pytorch_model.fp16.bin",
|
||||||
|
"safety_checker/model.fp16.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model.fp16.bin",
|
||||||
|
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||||
|
"text_encoder/pytorch_model.fp16.bin",
|
||||||
|
"text_encoder/model.fp16.safetensors",
|
||||||
|
"unet/diffusion_pytorch_model.fp16.bin",
|
||||||
|
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
|
||||||
|
]
|
||||||
|
variant = "fp16"
|
||||||
|
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
|
||||||
|
|
||||||
|
def test_transformer_model_is_compatible_variant(self):
|
||||||
|
filenames = [
|
||||||
|
"text_encoder/pytorch_model.fp16.bin",
|
||||||
|
"text_encoder/model.fp16.safetensors",
|
||||||
|
]
|
||||||
|
variant = "fp16"
|
||||||
|
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||||
|
|
||||||
|
def test_transformer_model_is_compatible_variant_partial(self):
|
||||||
|
# pass variant but use the non-variant filenames
|
||||||
|
filenames = [
|
||||||
|
"text_encoder/pytorch_model.bin",
|
||||||
|
"text_encoder/model.safetensors",
|
||||||
|
]
|
||||||
|
variant = "fp16"
|
||||||
|
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||||
|
|
||||||
|
def test_transformer_model_is_not_compatible_variant(self):
|
||||||
|
filenames = [
|
||||||
|
"safety_checker/pytorch_model.fp16.bin",
|
||||||
|
"safety_checker/model.fp16.safetensors",
|
||||||
|
"vae/diffusion_pytorch_model.fp16.bin",
|
||||||
|
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||||
|
"text_encoder/pytorch_model.fp16.bin",
|
||||||
|
# 'text_encoder/model.fp16.safetensors',
|
||||||
|
"unet/diffusion_pytorch_model.fp16.bin",
|
||||||
|
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||||
|
]
|
||||||
|
variant = "fp16"
|
||||||
|
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
|
||||||
Reference in New Issue
Block a user