Allow converting Flax to PyTorch by adding a "from_flax" keyword (#1900)
* from_flax * oops * oops * make style with pip install -e ".[dev]" * oops * now code quality happy 😋 * allow_patterns += FLAX_WEIGHTS_NAME * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/modeling_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * for test * bye bye is_flax_available() * oops * Update src/diffusers/models/modeling_pytorch_flax_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update src/diffusers/models/modeling_pytorch_flax_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update src/diffusers/models/modeling_pytorch_flax_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update src/diffusers/models/modeling_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update src/diffusers/models/modeling_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * make style * add test * finihs Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
@@ -0,0 +1,156 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch - Flax general utilities."""
|
||||||
|
|
||||||
|
from pickle import UnpicklingError
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from flax.serialization import from_bytes
|
||||||
|
from flax.traverse_util import flatten_dict
|
||||||
|
|
||||||
|
from ..utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
#####################
|
||||||
|
# Flax => PyTorch #
|
||||||
|
#####################
|
||||||
|
|
||||||
|
|
||||||
|
# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
|
||||||
|
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
|
||||||
|
try:
|
||||||
|
with open(model_file, "rb") as flax_state_f:
|
||||||
|
flax_state = from_bytes(None, flax_state_f.read())
|
||||||
|
except UnpicklingError as e:
|
||||||
|
try:
|
||||||
|
with open(model_file) as f:
|
||||||
|
if f.read().startswith("version"):
|
||||||
|
raise OSError(
|
||||||
|
"You seem to have cloned a repository without having git-lfs installed. Please"
|
||||||
|
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
||||||
|
" folder you cloned."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError from e
|
||||||
|
except (UnicodeDecodeError, ValueError):
|
||||||
|
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
|
||||||
|
|
||||||
|
return load_flax_weights_in_pytorch_model(pt_model, flax_state)
|
||||||
|
|
||||||
|
|
||||||
|
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||||
|
"""Load flax checkpoints in a PyTorch model"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
logger.error(
|
||||||
|
"Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
|
||||||
|
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||||
|
" instructions."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# check if we have bf16 weights
|
||||||
|
is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
|
||||||
|
if any(is_type_bf16):
|
||||||
|
# convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
|
||||||
|
|
||||||
|
# and bf16 is not fully supported in PT yet.
|
||||||
|
logger.warning(
|
||||||
|
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
|
||||||
|
"before loading those in PyTorch model."
|
||||||
|
)
|
||||||
|
flax_state = jax.tree_util.tree_map(
|
||||||
|
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
|
||||||
|
)
|
||||||
|
|
||||||
|
pt_model.base_model_prefix = ""
|
||||||
|
|
||||||
|
flax_state_dict = flatten_dict(flax_state, sep=".")
|
||||||
|
pt_model_dict = pt_model.state_dict()
|
||||||
|
|
||||||
|
# keep track of unexpected & missing keys
|
||||||
|
unexpected_keys = []
|
||||||
|
missing_keys = set(pt_model_dict.keys())
|
||||||
|
|
||||||
|
for flax_key_tuple, flax_tensor in flax_state_dict.items():
|
||||||
|
flax_key_tuple_array = flax_key_tuple.split(".")
|
||||||
|
|
||||||
|
if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
|
||||||
|
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
||||||
|
flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
|
||||||
|
elif flax_key_tuple_array[-1] == "kernel":
|
||||||
|
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
||||||
|
flax_tensor = flax_tensor.T
|
||||||
|
elif flax_key_tuple_array[-1] == "scale":
|
||||||
|
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
||||||
|
|
||||||
|
if "time_embedding" not in flax_key_tuple_array:
|
||||||
|
for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
|
||||||
|
flax_key_tuple_array[i] = (
|
||||||
|
flax_key_tuple_string.replace("_0", ".0")
|
||||||
|
.replace("_1", ".1")
|
||||||
|
.replace("_2", ".2")
|
||||||
|
.replace("_3", ".3")
|
||||||
|
)
|
||||||
|
|
||||||
|
flax_key = ".".join(flax_key_tuple_array)
|
||||||
|
|
||||||
|
if flax_key in pt_model_dict:
|
||||||
|
if flax_tensor.shape != pt_model_dict[flax_key].shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
|
||||||
|
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# add weight to pytorch dict
|
||||||
|
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
|
||||||
|
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
|
||||||
|
# remove from missing keys
|
||||||
|
missing_keys.remove(flax_key)
|
||||||
|
else:
|
||||||
|
# weight is not expected by PyTorch model
|
||||||
|
unexpected_keys.append(flax_key)
|
||||||
|
|
||||||
|
pt_model.load_state_dict(pt_model_dict)
|
||||||
|
|
||||||
|
# re-transform missing_keys to list
|
||||||
|
missing_keys = list(missing_keys)
|
||||||
|
|
||||||
|
if len(unexpected_keys) > 0:
|
||||||
|
logger.warning(
|
||||||
|
"Some weights of the Flax model were not used when initializing the PyTorch model"
|
||||||
|
f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
|
||||||
|
f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
|
||||||
|
" (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
|
||||||
|
f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
|
||||||
|
" to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
|
||||||
|
" FlaxBertForSequenceClassification model)."
|
||||||
|
)
|
||||||
|
if len(missing_keys) > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
|
||||||
|
f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
|
||||||
|
" use it for predictions and inference."
|
||||||
|
)
|
||||||
|
|
||||||
|
return pt_model
|
||||||
@@ -30,6 +30,7 @@ from .. import __version__
|
|||||||
from ..utils import (
|
from ..utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
DIFFUSERS_CACHE,
|
DIFFUSERS_CACHE,
|
||||||
|
FLAX_WEIGHTS_NAME,
|
||||||
HF_HUB_OFFLINE,
|
HF_HUB_OFFLINE,
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||||
SAFETENSORS_WEIGHTS_NAME,
|
SAFETENSORS_WEIGHTS_NAME,
|
||||||
@@ -335,6 +336,8 @@ class ModelMixin(torch.nn.Module):
|
|||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||||
identifier allowed by git.
|
identifier allowed by git.
|
||||||
|
from_flax (`bool`, *optional*, defaults to `False`):
|
||||||
|
Load the model weights from a Flax checkpoint save file.
|
||||||
subfolder (`str`, *optional*, defaults to `""`):
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
||||||
huggingface.co or downloaded locally), you can specify the folder name here.
|
huggingface.co or downloaded locally), you can specify the folder name here.
|
||||||
@@ -375,6 +378,7 @@ class ModelMixin(torch.nn.Module):
|
|||||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||||
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
from_flax = kwargs.pop("from_flax", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||||
@@ -433,6 +437,41 @@ class ModelMixin(torch.nn.Module):
|
|||||||
# Load model
|
# Load model
|
||||||
|
|
||||||
model_file = None
|
model_file = None
|
||||||
|
if from_flax:
|
||||||
|
model_file = cls._get_model_file(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
weights_name=FLAX_WEIGHTS_NAME,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
config, unused_kwargs = cls.load_config(
|
||||||
|
config_path,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
return_unused_kwargs=True,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
device_map=device_map,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
model = cls.from_config(config, **unused_kwargs)
|
||||||
|
|
||||||
|
# Convert the weights
|
||||||
|
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
||||||
|
|
||||||
|
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
||||||
|
else:
|
||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
try:
|
try:
|
||||||
model_file = cls._get_model_file(
|
model_file = cls._get_model_file(
|
||||||
@@ -484,15 +523,19 @@ class ModelMixin(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
model = cls.from_config(config, **unused_kwargs)
|
model = cls.from_config(config, **unused_kwargs)
|
||||||
|
|
||||||
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
|
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
param_device = "cpu"
|
param_device = "cpu"
|
||||||
state_dict = load_state_dict(model_file)
|
state_dict = load_state_dict(model_file)
|
||||||
# move the parms from meta device to cpu
|
# move the params from meta device to cpu
|
||||||
for param_name, param in state_dict.items():
|
for param_name, param in state_dict.items():
|
||||||
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
|
accepts_dtype = "dtype" in set(
|
||||||
|
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
||||||
|
)
|
||||||
if accepts_dtype:
|
if accepts_dtype:
|
||||||
set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
|
set_module_tensor_to_device(
|
||||||
|
model, param_name, param_device, value=param, dtype=torch_dtype
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||||
else: # else let accelerate handle loading and dispatching.
|
else: # else let accelerate handle loading and dispatching.
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
|||||||
from ..utils import (
|
from ..utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
DIFFUSERS_CACHE,
|
DIFFUSERS_CACHE,
|
||||||
|
FLAX_WEIGHTS_NAME,
|
||||||
HF_HUB_OFFLINE,
|
HF_HUB_OFFLINE,
|
||||||
ONNX_WEIGHTS_NAME,
|
ONNX_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
@@ -445,6 +446,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
|
from_flax = kwargs.pop("from_flax", False)
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
||||||
custom_revision = kwargs.pop("custom_revision", None)
|
custom_revision = kwargs.pop("custom_revision", None)
|
||||||
@@ -470,11 +472,26 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
# make sure we only download sub-folders and `diffusers` filenames
|
# make sure we only download sub-folders and `diffusers` filenames
|
||||||
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
|
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
|
||||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||||
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
|
allow_patterns += [
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
SCHEDULER_CONFIG_NAME,
|
||||||
|
CONFIG_NAME,
|
||||||
|
ONNX_WEIGHTS_NAME,
|
||||||
|
cls.config_name,
|
||||||
|
]
|
||||||
|
|
||||||
# make sure we don't download flax weights
|
# make sure we don't download flax weights
|
||||||
ignore_patterns = ["*.msgpack"]
|
ignore_patterns = ["*.msgpack"]
|
||||||
|
|
||||||
|
if from_flax:
|
||||||
|
ignore_patterns = ["*.bin", "*.safetensors"]
|
||||||
|
allow_patterns += [
|
||||||
|
SCHEDULER_CONFIG_NAME,
|
||||||
|
CONFIG_NAME,
|
||||||
|
FLAX_WEIGHTS_NAME,
|
||||||
|
cls.config_name,
|
||||||
|
]
|
||||||
|
|
||||||
if custom_pipeline is not None:
|
if custom_pipeline is not None:
|
||||||
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
||||||
|
|
||||||
@@ -704,7 +721,14 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||||
if is_diffusers_model or is_transformers_model:
|
if is_diffusers_model or is_transformers_model:
|
||||||
loading_kwargs["device_map"] = device_map
|
loading_kwargs["device_map"] = device_map
|
||||||
|
if from_flax:
|
||||||
|
loading_kwargs["from_flax"] = True
|
||||||
|
|
||||||
|
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
|
||||||
|
if not (from_flax and is_transformers_model):
|
||||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||||
|
else:
|
||||||
|
loading_kwargs["low_cpu_mem_usage"] = False
|
||||||
|
|
||||||
# check if the module is in a subdirectory
|
# check if the module is in a subdirectory
|
||||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||||
|
|||||||
+44
-1
@@ -47,7 +47,7 @@ from diffusers import (
|
|||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||||
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, nightly, slow, torch_device
|
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
|
||||||
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
|
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -816,6 +816,49 @@ class PipelineSlowTests(unittest.TestCase):
|
|||||||
assert isinstance(images, list)
|
assert isinstance(images, list)
|
||||||
assert isinstance(images[0], PIL.Image.Image)
|
assert isinstance(images[0], PIL.Image.Image)
|
||||||
|
|
||||||
|
def test_from_flax_from_pt(self):
|
||||||
|
pipe_pt = StableDiffusionPipeline.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
||||||
|
)
|
||||||
|
pipe_pt.to(torch_device)
|
||||||
|
|
||||||
|
if not is_flax_available():
|
||||||
|
raise ImportError("Make sure flax is installed.")
|
||||||
|
|
||||||
|
from diffusers import FlaxStableDiffusionPipeline
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pipe_pt.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
pipe_flax, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||||
|
tmpdirname, safety_checker=None, from_pt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pipe_flax.save_pretrained(tmpdirname, params=params)
|
||||||
|
pipe_pt_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None, from_flax=True)
|
||||||
|
pipe_pt_2.to(torch_device)
|
||||||
|
|
||||||
|
prompt = "Hello"
|
||||||
|
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
image_0 = pipe_pt(
|
||||||
|
[prompt],
|
||||||
|
generator=generator,
|
||||||
|
num_inference_steps=2,
|
||||||
|
output_type="np",
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
image_1 = pipe_pt_2(
|
||||||
|
[prompt],
|
||||||
|
generator=generator,
|
||||||
|
num_inference_steps=2,
|
||||||
|
output_type="np",
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
assert np.abs(image_0 - image_1).sum() < 1e-5, "Models don't give the same forward pass"
|
||||||
|
|
||||||
|
|
||||||
@nightly
|
@nightly
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
|
|||||||
Reference in New Issue
Block a user