[Low CPU memory] + device map (#772)
* add accelerate to load models with smaller memory footprint * remove low_cpu_mem_usage as it is reduntant * move accelerate init weights context to modelling utils * add test to ensure results are the same when loading with accelerate * add tests to ensure ram usage gets lower when using accelerate * move accelerate logic to single snippet under modelling utils and remove it from configuration utils * format code using to pass quality check * fix imports with isor * add accelerate to test extra deps * only import accelerate if device_map is set to auto * move accelerate availability check to diffusers import utils * format code * add device map to pipeline abstraction * lint it to pass PR quality check * fix class check to use accelerate when using diffusers ModelMixin subclasses * use low_cpu_mem_usage in transformers if device_map is not available * NoModuleLayer * comment out tests * up * uP * finish * Update src/diffusers/pipelines/stable_diffusion/safety_checker.py * finish * uP * make style Co-authored-by: Pi Esposito <piero.skywalker@gmail.com>
This commit is contained in:
committed by
GitHub
parent
feaa73243d
commit
fab17528da
@@ -32,7 +32,19 @@ from tqdm.auto import tqdm
|
|||||||
from .configuration_utils import ConfigMixin
|
from .configuration_utils import ConfigMixin
|
||||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||||
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
|
from .utils import (
|
||||||
|
CONFIG_NAME,
|
||||||
|
DIFFUSERS_CACHE,
|
||||||
|
ONNX_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
BaseOutput,
|
||||||
|
is_transformers_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_transformers_available():
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
INDEX_FILE = "diffusion_pytorch_model.bin"
|
||||||
@@ -338,6 +350,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
||||||
provider = kwargs.pop("provider", None)
|
provider = kwargs.pop("provider", None)
|
||||||
sess_options = kwargs.pop("sess_options", None)
|
sess_options = kwargs.pop("sess_options", None)
|
||||||
|
device_map = kwargs.pop("device_map", None)
|
||||||
|
|
||||||
# 1. Download the checkpoints and configs
|
# 1. Download the checkpoints and configs
|
||||||
# use snapshot download here to get it working from from_pretrained
|
# use snapshot download here to get it working from from_pretrained
|
||||||
@@ -463,6 +476,13 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
loading_kwargs["provider"] = provider
|
loading_kwargs["provider"] = provider
|
||||||
loading_kwargs["sess_options"] = sess_options
|
loading_kwargs["sess_options"] = sess_options
|
||||||
|
|
||||||
|
if (
|
||||||
|
issubclass(class_obj, diffusers.ModelMixin)
|
||||||
|
or is_transformers_available()
|
||||||
|
and issubclass(class_obj, PreTrainedModel)
|
||||||
|
):
|
||||||
|
loading_kwargs["device_map"] = device_map
|
||||||
|
|
||||||
# check if the module is in a subdirectory
|
# check if the module is in a subdirectory
|
||||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||||
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ def cosine_distance(image_embeds, text_embeds):
|
|||||||
class StableDiffusionSafetyChecker(PreTrainedModel):
|
class StableDiffusionSafetyChecker(PreTrainedModel):
|
||||||
config_class = CLIPConfig
|
config_class = CLIPConfig
|
||||||
|
|
||||||
|
_no_split_modules = ["CLIPEncoderLayer"]
|
||||||
|
|
||||||
def __init__(self, config: CLIPConfig):
|
def __init__(self, config: CLIPConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@@ -28,8 +30,8 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
|
|||||||
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
|
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
|
||||||
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
|
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
|
||||||
|
|
||||||
self.register_buffer("concept_embeds_weights", torch.ones(17))
|
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
|
||||||
self.register_buffer("special_care_embeds_weights", torch.ones(3))
|
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, clip_input, images):
|
def forward(self, clip_input, images):
|
||||||
|
|||||||
@@ -17,12 +17,15 @@ import gc
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import tracemalloc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import accelerate
|
||||||
import PIL
|
import PIL
|
||||||
|
import transformers
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDIMPipeline,
|
DDIMPipeline,
|
||||||
@@ -50,6 +53,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
|
|||||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||||
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
|
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
|
||||||
from diffusers.utils.testing_utils import get_tests_dir
|
from diffusers.utils.testing_utils import get_tests_dir
|
||||||
|
from packaging import version
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
@@ -2034,3 +2038,53 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||||||
pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
|
pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
|
||||||
assert test_callback_fn.has_been_called
|
assert test_callback_fn.has_been_called
|
||||||
assert number_of_steps == 6
|
assert number_of_steps == 6
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||||
|
def test_stable_diffusion_accelerate_load_works(self):
|
||||||
|
if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
|
||||||
|
return
|
||||||
|
|
||||||
|
if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
|
||||||
|
return
|
||||||
|
|
||||||
|
model_id = "CompVis/stable-diffusion-v1-4"
|
||||||
|
_ = StableDiffusionPipeline.from_pretrained(
|
||||||
|
model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
|
||||||
|
def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
|
||||||
|
if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
|
||||||
|
return
|
||||||
|
|
||||||
|
if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
|
||||||
|
return
|
||||||
|
|
||||||
|
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
tracemalloc.start()
|
||||||
|
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
|
||||||
|
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
|
||||||
|
)
|
||||||
|
pipeline_normal_load.to(torch_device)
|
||||||
|
_, peak_normal = tracemalloc.get_traced_memory()
|
||||||
|
tracemalloc.stop()
|
||||||
|
|
||||||
|
del pipeline_normal_load
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
tracemalloc.start()
|
||||||
|
_ = StableDiffusionPipeline.from_pretrained(
|
||||||
|
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
|
||||||
|
)
|
||||||
|
_, peak_accelerate = tracemalloc.get_traced_memory()
|
||||||
|
|
||||||
|
tracemalloc.stop()
|
||||||
|
|
||||||
|
assert peak_accelerate < peak_normal
|
||||||
|
|||||||
Reference in New Issue
Block a user