Align PT and Flax API - allow loading checkpoint from PyTorch configs (#827)
* up * finish * add more tests * up * up * finish
This commit is contained in:
committed by
GitHub
parent
78db11dbf3
commit
7c2262640b
@@ -111,6 +111,9 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||||||
from diffusers import pipelines
|
from diffusers import pipelines
|
||||||
|
|
||||||
for name, module in kwargs.items():
|
for name, module in kwargs.items():
|
||||||
|
if module is None:
|
||||||
|
register_dict = {name: (None, None)}
|
||||||
|
else:
|
||||||
# retrieve library
|
# retrieve library
|
||||||
library = module.__module__.split(".")[0]
|
library = module.__module__.split(".")[0]
|
||||||
|
|
||||||
@@ -320,6 +323,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||||||
pipeline_class = cls
|
pipeline_class = cls
|
||||||
else:
|
else:
|
||||||
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
||||||
|
class_name = (
|
||||||
|
config_dict["_class_name"]
|
||||||
|
if config_dict["_class_name"].startswith("Flax")
|
||||||
|
else "Flax" + config_dict["_class_name"]
|
||||||
|
)
|
||||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
||||||
|
|
||||||
# some modules can be passed directly to the init
|
# some modules can be passed directly to the init
|
||||||
@@ -342,6 +350,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||||||
for name, (library_name, class_name) in init_dict.items():
|
for name, (library_name, class_name) in init_dict.items():
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
loaded_sub_model = None
|
loaded_sub_model = None
|
||||||
|
sub_model_should_be_defined = True
|
||||||
|
|
||||||
# if the model is in a pipeline module, then we load it from the pipeline
|
# if the model is in a pipeline module, then we load it from the pipeline
|
||||||
if name in passed_class_obj:
|
if name in passed_class_obj:
|
||||||
@@ -362,6 +371,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||||||
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
||||||
f" {expected_class_obj}"
|
f" {expected_class_obj}"
|
||||||
)
|
)
|
||||||
|
elif passed_class_obj[name] is None:
|
||||||
|
logger.warn(
|
||||||
|
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
|
||||||
|
f" that this might lead to problems when using {pipeline_class} and is not recommended."
|
||||||
|
)
|
||||||
|
sub_model_should_be_defined = False
|
||||||
else:
|
else:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||||
@@ -372,25 +387,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||||||
loaded_sub_model = passed_class_obj[name]
|
loaded_sub_model = passed_class_obj[name]
|
||||||
elif is_pipeline_module:
|
elif is_pipeline_module:
|
||||||
pipeline_module = getattr(pipelines, library_name)
|
pipeline_module = getattr(pipelines, library_name)
|
||||||
if from_pt:
|
|
||||||
class_obj = import_flax_or_no_model(pipeline_module, class_name)
|
class_obj = import_flax_or_no_model(pipeline_module, class_name)
|
||||||
else:
|
|
||||||
class_obj = getattr(pipeline_module, class_name)
|
|
||||||
|
|
||||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||||
else:
|
else:
|
||||||
# else we just import it from the library.
|
# else we just import it from the library.
|
||||||
library = importlib.import_module(library_name)
|
library = importlib.import_module(library_name)
|
||||||
if from_pt:
|
|
||||||
class_obj = import_flax_or_no_model(library, class_name)
|
class_obj = import_flax_or_no_model(library, class_name)
|
||||||
else:
|
|
||||||
class_obj = getattr(library, class_name)
|
|
||||||
|
|
||||||
importable_classes = LOADABLE_CLASSES[library_name]
|
importable_classes = LOADABLE_CLASSES[library_name]
|
||||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||||
|
|
||||||
if loaded_sub_model is None:
|
if loaded_sub_model is None and sub_model_should_be_defined:
|
||||||
load_method_name = None
|
load_method_name = None
|
||||||
for class_name, class_candidate in class_candidates.items():
|
for class_name, class_candidate in class_candidates.items():
|
||||||
if issubclass(class_obj, class_candidate):
|
if issubclass(class_obj, class_candidate):
|
||||||
|
|||||||
@@ -14,10 +14,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
|||||||
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
|
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
|
||||||
from ...pipeline_flax_utils import FlaxDiffusionPipeline
|
from ...pipeline_flax_utils import FlaxDiffusionPipeline
|
||||||
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
|
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
|
||||||
|
from ...utils import logging
|
||||||
from . import FlaxStableDiffusionPipelineOutput
|
from . import FlaxStableDiffusionPipelineOutput
|
||||||
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||||
r"""
|
r"""
|
||||||
Pipeline for text-to-image generation using Stable Diffusion.
|
Pipeline for text-to-image generation using Stable Diffusion.
|
||||||
@@ -60,6 +64,16 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
if safety_checker is None:
|
||||||
|
logger.warn(
|
||||||
|
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||||
|
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||||
|
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||||
|
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||||
|
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||||
|
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||||
|
)
|
||||||
|
|
||||||
self.register_modules(
|
self.register_modules(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
@@ -265,10 +279,23 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||||||
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.safety_checker is not None:
|
||||||
safety_params = params["safety_checker"]
|
safety_params = params["safety_checker"]
|
||||||
images = (images * 255).round().astype("uint8")
|
images_uint8_casted = (images * 255).round().astype("uint8")
|
||||||
images = np.asarray(images).reshape(-1, height, width, 3)
|
num_devices, batch_size = images.shape[:2]
|
||||||
images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)
|
|
||||||
|
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
|
||||||
|
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
|
||||||
|
images = np.asarray(images)
|
||||||
|
|
||||||
|
# block images
|
||||||
|
if any(has_nsfw_concept):
|
||||||
|
for i, is_nsfw in enumerate(has_nsfw_concept):
|
||||||
|
images[i] = np.asarray(images_uint8_casted[i])
|
||||||
|
|
||||||
|
images = images.reshape(num_devices, batch_size, height, width, 3)
|
||||||
|
else:
|
||||||
|
has_nsfw_concept = False
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (images, has_nsfw_concept)
|
return (images, has_nsfw_concept)
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
if safety_checker is None:
|
if safety_checker is None:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
if safety_checker is None:
|
if safety_checker is None:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
if safety_checker is None:
|
if safety_checker is None:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from diffusers.utils.testing_utils import require_flax, slow
|
|||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
import jax
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
from diffusers import FlaxStableDiffusionPipeline
|
from diffusers import FlaxStableDiffusionPipeline
|
||||||
from flax.jax_utils import replicate
|
from flax.jax_utils import replicate
|
||||||
from flax.training.common_utils import shard
|
from flax.training.common_utils import shard
|
||||||
@@ -34,7 +35,7 @@ if is_flax_available():
|
|||||||
class FlaxPipelineTests(unittest.TestCase):
|
class FlaxPipelineTests(unittest.TestCase):
|
||||||
def test_dummy_all_tpus(self):
|
def test_dummy_all_tpus(self):
|
||||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||||
"hf-internal-testing/tiny-stable-diffusion-pipe"
|
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
@@ -57,6 +58,103 @@ class FlaxPipelineTests(unittest.TestCase):
|
|||||||
prompt_ids = shard(prompt_ids)
|
prompt_ids = shard(prompt_ids)
|
||||||
|
|
||||||
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
||||||
|
|
||||||
|
assert images.shape == (8, 1, 64, 64, 3)
|
||||||
|
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3
|
||||||
|
assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 1e-2
|
||||||
|
|
||||||
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
||||||
|
|
||||||
assert len(images_pil) == 8
|
assert len(images_pil) == 8
|
||||||
|
|
||||||
|
def test_stable_diffusion_v1_4(self):
|
||||||
|
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||||
|
"CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
|
||||||
|
" field, close up, split lighting, cinematic"
|
||||||
|
)
|
||||||
|
|
||||||
|
prng_seed = jax.random.PRNGKey(0)
|
||||||
|
num_inference_steps = 50
|
||||||
|
|
||||||
|
num_samples = jax.device_count()
|
||||||
|
prompt = num_samples * [prompt]
|
||||||
|
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||||
|
|
||||||
|
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
|
||||||
|
|
||||||
|
# shard inputs and rng
|
||||||
|
params = replicate(params)
|
||||||
|
prng_seed = jax.random.split(prng_seed, 8)
|
||||||
|
prompt_ids = shard(prompt_ids)
|
||||||
|
|
||||||
|
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
||||||
|
|
||||||
|
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
||||||
|
for i, image in enumerate(images_pil):
|
||||||
|
image.save(f"/home/patrick/images/flax-test-{i}_fp32.png")
|
||||||
|
|
||||||
|
assert images.shape == (8, 1, 512, 512, 3)
|
||||||
|
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
|
||||||
|
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 1e-2
|
||||||
|
|
||||||
|
def test_stable_diffusion_v1_4_bfloat_16(self):
|
||||||
|
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||||
|
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
|
||||||
|
" field, close up, split lighting, cinematic"
|
||||||
|
)
|
||||||
|
|
||||||
|
prng_seed = jax.random.PRNGKey(0)
|
||||||
|
num_inference_steps = 50
|
||||||
|
|
||||||
|
num_samples = jax.device_count()
|
||||||
|
prompt = num_samples * [prompt]
|
||||||
|
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||||
|
|
||||||
|
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
|
||||||
|
|
||||||
|
# shard inputs and rng
|
||||||
|
params = replicate(params)
|
||||||
|
prng_seed = jax.random.split(prng_seed, 8)
|
||||||
|
prompt_ids = shard(prompt_ids)
|
||||||
|
|
||||||
|
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
||||||
|
|
||||||
|
assert images.shape == (8, 1, 512, 512, 3)
|
||||||
|
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
|
||||||
|
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2
|
||||||
|
|
||||||
|
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
|
||||||
|
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||||
|
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
|
||||||
|
" field, close up, split lighting, cinematic"
|
||||||
|
)
|
||||||
|
|
||||||
|
prng_seed = jax.random.PRNGKey(0)
|
||||||
|
num_inference_steps = 50
|
||||||
|
|
||||||
|
num_samples = jax.device_count()
|
||||||
|
prompt = num_samples * [prompt]
|
||||||
|
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||||
|
|
||||||
|
# shard inputs and rng
|
||||||
|
params = replicate(params)
|
||||||
|
prng_seed = jax.random.split(prng_seed, 8)
|
||||||
|
prompt_ids = shard(prompt_ids)
|
||||||
|
|
||||||
|
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
||||||
|
|
||||||
|
assert images.shape == (8, 1, 512, 512, 3)
|
||||||
|
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
|
||||||
|
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2
|
||||||
|
|||||||
Reference in New Issue
Block a user