[From pretrained] Allow returning local path (#1450)
Allow returning local path
This commit is contained in:
committed by
GitHub
parent
25f850a23b
commit
22b9cb086b
@@ -377,7 +377,8 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||||
setting this argument to `True` will raise an error.
|
setting this argument to `True` will raise an error.
|
||||||
|
return_cached_folder (`bool`, *optional*, defaults to `False`):
|
||||||
|
If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
|
||||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||||
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
||||||
@@ -430,33 +431,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
sess_options = kwargs.pop("sess_options", None)
|
sess_options = kwargs.pop("sess_options", None)
|
||||||
device_map = kwargs.pop("device_map", None)
|
device_map = kwargs.pop("device_map", None)
|
||||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||||
|
return_cached_folder = kwargs.pop("return_cached_folder", False)
|
||||||
if low_cpu_mem_usage and not is_accelerate_available():
|
|
||||||
low_cpu_mem_usage = False
|
|
||||||
logger.warning(
|
|
||||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
|
||||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
|
||||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
|
||||||
" install accelerate\n```\n."
|
|
||||||
)
|
|
||||||
|
|
||||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
|
||||||
" `device_map=None`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
|
||||||
" `low_cpu_mem_usage=False`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage is False and device_map is not None:
|
|
||||||
raise ValueError(
|
|
||||||
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
|
||||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. Download the checkpoints and configs
|
# 1. Download the checkpoints and configs
|
||||||
# use snapshot download here to get it working from from_pretrained
|
# use snapshot download here to get it working from from_pretrained
|
||||||
@@ -585,6 +560,33 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage and not is_accelerate_available():
|
||||||
|
low_cpu_mem_usage = False
|
||||||
|
logger.warning(
|
||||||
|
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||||
|
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||||
|
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||||
|
" install accelerate\n```\n."
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||||
|
" `device_map=None`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||||
|
" `low_cpu_mem_usage=False`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage is False and device_map is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
||||||
|
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||||
|
)
|
||||||
|
|
||||||
# import it here to avoid circular import
|
# import it here to avoid circular import
|
||||||
from diffusers import pipelines
|
from diffusers import pipelines
|
||||||
|
|
||||||
@@ -704,6 +706,9 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
|
|
||||||
# 5. Instantiate the pipeline
|
# 5. Instantiate the pipeline
|
||||||
model = pipeline_class(**init_kwargs)
|
model = pipeline_class(**init_kwargs)
|
||||||
|
|
||||||
|
if return_cached_folder:
|
||||||
|
return model, cached_folder
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -95,6 +95,35 @@ class DownloadTests(unittest.TestCase):
|
|||||||
# We need to never convert this tiny model to safetensors for this test to pass
|
# We need to never convert this tiny model to safetensors for this test to pass
|
||||||
assert not any(f.endswith(".safetensors") for f in files)
|
assert not any(f.endswith(".safetensors") for f in files)
|
||||||
|
|
||||||
|
def test_returned_cached_folder(self):
|
||||||
|
prompt = "hello"
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
||||||
|
)
|
||||||
|
_, local_path = StableDiffusionPipeline.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, return_cached_folder=True
|
||||||
|
)
|
||||||
|
pipe_2 = StableDiffusionPipeline.from_pretrained(local_path)
|
||||||
|
|
||||||
|
pipe = pipe.to(torch_device)
|
||||||
|
pipe_2 = pipe.to(torch_device)
|
||||||
|
if torch_device == "mps":
|
||||||
|
# device type MPS is not supported for torch.Generator() api.
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
|
|
||||||
|
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
|
if torch_device == "mps":
|
||||||
|
# device type MPS is not supported for torch.Generator() api.
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
|
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
|
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||||
|
|
||||||
def test_download_safetensors(self):
|
def test_download_safetensors(self):
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
# pipeline has Flax weights
|
# pipeline has Flax weights
|
||||||
|
|||||||
Reference in New Issue
Block a user