[Deterministic torch randn] Allow tensors to be generated on CPU (#1902)
* [Deterministic torch randn] Allow tensors to be generated on CPU * fix more * up * fix more * up * Update src/diffusers/utils/torch_utils.py Co-authored-by: Anton Lozhkov <anton@huggingface.co> * Apply suggestions from code review * up * up * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Anton Lozhkov <anton@huggingface.co> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
0df83c79e4
commit
8ed08e4270
@ -336,7 +336,7 @@ class TextualInversionDataset(Dataset):
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
(h, w,) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
||||
@ -381,7 +381,7 @@ class TextualInversionDataset(Dataset):
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
(h, w,) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
||||
@ -306,7 +306,7 @@ class TextualInversionDataset(Dataset):
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
(h, w,) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
||||
@ -564,6 +564,7 @@ def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model,
|
||||
|
||||
# unet utils
|
||||
|
||||
|
||||
# <original>.time_embed -> <diffusers>.time_embedding
|
||||
def unet_time_embeddings(checkpoint, original_unet_prefix):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
@ -37,6 +37,7 @@ def rename_key(key):
|
||||
# PyTorch => Flax #
|
||||
#####################
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
|
||||
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
|
||||
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
|
||||
|
||||
@ -24,7 +24,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
||||
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
|
||||
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import is_accelerate_available, logging
|
||||
from ...utils import is_accelerate_available, logging, torch_randn
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
@ -105,11 +105,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
|
||||
@ -29,7 +29,7 @@ from transformers import (
|
||||
from ...models import UNet2DConditionModel, UNet2DModel
|
||||
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import is_accelerate_available, logging
|
||||
from ...utils import is_accelerate_available, logging, torch_randn
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
@ -113,11 +113,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
|
||||
@ -20,7 +20,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils import BaseOutput, torch_randn
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@ -273,15 +273,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
|
||||
# 6. Add noise
|
||||
variance = 0
|
||||
if t > 0:
|
||||
device = model_output.device
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
|
||||
variance_noise = variance_noise.to(device)
|
||||
else:
|
||||
variance_noise = torch.randn(
|
||||
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
||||
)
|
||||
variance_noise = torch_randn(
|
||||
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
|
||||
)
|
||||
|
||||
variance = self._get_variance(
|
||||
t,
|
||||
|
||||
@ -64,6 +64,7 @@ from .import_utils import (
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .pil_utils import PIL_INTERPOLATION
|
||||
from .torch_utils import torch_randn
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
||||
64
src/diffusers/utils/torch_utils.py
Normal file
64
src/diffusers/utils/torch_utils.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
PyTorch utilities: Utilities related to PyTorch
|
||||
"""
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from . import logging
|
||||
from .import_utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def torch_randn(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
):
    """Helper that creates a random tensor on the desired `device` with the desired `dtype`.

    When a list of generators is passed, each batch entry (dim 0 of `shape`) is sampled with its own
    generator, so individual samples can be seeded independently. If CPU generator(s) are passed while
    `device` is not CPU, the tensor is created on CPU first and then moved to `device` (this keeps
    results reproducible across devices, at a small transfer cost).

    Args:
        shape: Shape of the tensor to create; `shape[0]` is treated as the batch size.
        generator: A single `torch.Generator` or a list of them (one per batch entry).
        device: Device the returned tensor should live on.
        dtype: dtype of the returned tensor.

    Returns:
        A tensor of the requested shape on `device`.

    Raises:
        ValueError: If a CUDA generator is passed for a non-CUDA `device`.
    """
    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    if generator is not None:
        # A list of generators shares one device; inspect the first entry so we never
        # dereference `.device` on the list itself (that would raise AttributeError).
        gen_device = generator[0].device if isinstance(generator, list) else generator.device
        if gen_device != device and gen_device.type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device.type != device.type and gen_device.type == "cuda":
            # A CUDA generator cannot produce tensors for a different device type.
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device.type}.")

    if isinstance(generator, list):
        # Sample each batch entry separately with its own generator, then concatenate.
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)

    return latents
|
||||
@ -382,7 +382,7 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipeline(
|
||||
"horse",
|
||||
num_images_per_prompt=1,
|
||||
|
||||
@ -480,7 +480,7 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
pipeline.enable_sequential_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipeline(
|
||||
input_image,
|
||||
num_images_per_prompt=1,
|
||||
|
||||
@ -96,6 +96,7 @@ def ignore_underscore(key):
|
||||
|
||||
def sort_objects(objects, key=None):
|
||||
"Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
|
||||
|
||||
# If no key is provided, we use a noop.
|
||||
def noop(x):
|
||||
return x
|
||||
@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement):
|
||||
"""
|
||||
Return the same `import_statement` but with objects properly sorted.
|
||||
"""
|
||||
|
||||
# This inner function sort imports between [ ].
|
||||
def _replace(match):
|
||||
imports = match.groups()[0]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user