diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 01bcc6a338..fb39b6b3dd 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -377,7 +377,8 @@ class DiffusionPipeline(ConfigMixin):
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                 setting this argument to `True` will raise an error.
-
+            return_cached_folder (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the path to the downloaded cached folder is returned in addition to the loaded pipeline.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
                 specific pipeline class. The overwritten components are then directly passed to the pipelines
@@ -430,33 +431,7 @@ class DiffusionPipeline(ConfigMixin):
         sess_options = kwargs.pop("sess_options", None)
         device_map = kwargs.pop("device_map", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
-        if low_cpu_mem_usage and not is_accelerate_available():
-            low_cpu_mem_usage = False
-            logger.warning(
-                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
-                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
-                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
-                " install accelerate\n```\n."
-            )
-
-        if device_map is not None and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError(
-                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
-                " `device_map=None`."
-            )
-
-        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError(
-                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
-                " `low_cpu_mem_usage=False`."
-            )
-
-        if low_cpu_mem_usage is False and device_map is not None:
-            raise ValueError(
-                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
-                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
-            )
+        return_cached_folder = kwargs.pop("return_cached_folder", False)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -585,6 +560,33 @@ class DiffusionPipeline(ConfigMixin):
                 f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
             )
 
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if device_map is not None and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
+
         # import it here to avoid circular import
         from diffusers import pipelines
 
@@ -704,6 +706,9 @@ class DiffusionPipeline(ConfigMixin):
 
         # 5. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)
+
+        if return_cached_folder:
+            return model, cached_folder
         return model
 
     @staticmethod
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index ee0a087561..630cec65ff 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -95,6 +95,35 @@ class DownloadTests(unittest.TestCase):
         # We need to never convert this tiny model to safetensors for this test to pass
         assert not any(f.endswith(".safetensors") for f in files)
 
+    def test_returned_cached_folder(self):
+        prompt = "hello"
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+        )
+        _, local_path = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, return_cached_folder=True
+        )
+        pipe_2 = StableDiffusionPipeline.from_pretrained(local_path)
+
+        pipe = pipe.to(torch_device)
+        pipe_2 = pipe_2.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
+
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
+
+        assert np.max(np.abs(out - out_2)) < 1e-3
+
     def test_download_safetensors(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             # pipeline has Flax weights
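For reviewers, a minimal usage sketch of the new `return_cached_folder` flag, mirroring what `test_returned_cached_folder` exercises; the tiny checkpoint name is reused from the diff, but any Hub pipeline repo should behave the same way:

```python
from diffusers import StableDiffusionPipeline

# With return_cached_folder=True, from_pretrained returns a (pipeline, path)
# tuple instead of just the pipeline.
pipe, cached_folder = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch",
    safety_checker=None,
    return_cached_folder=True,
)
print(cached_folder)  # local snapshot directory the checkpoint was downloaded to

# The returned path can be fed straight back into from_pretrained to reload the
# pipeline from disk, which is exactly what the new test asserts.
pipe_2 = StableDiffusionPipeline.from_pretrained(cached_folder)
```

Note that the flag changes the return type of `from_pretrained` from a pipeline to a `(pipeline, str)` tuple, so callers have to opt in explicitly via the keyword argument.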