Improve reproducibility 2/3 (#1906)

* [Repro] Correct reproducibility

* up

* up

* uP

* up

* need better image

* allow conversion from no state dict checkpoints

* up

* up

* up

* up

* check tensors

* check tensors

* check tensors

* check tensors

* next try

* up

* up

* better name

* up

* up

* Apply suggestions from code review

* correct more

* up

* replace all torch randn

* fix

* correct

* correct

* finish

* fix more

* up
This commit is contained in:
Patrick von Platen
2023-01-05 02:51:17 +04:00
committed by GitHub
parent 67e2f95cc4
commit 9b63854886
49 changed files with 171 additions and 391 deletions
@@ -19,6 +19,7 @@ import tqdm
from ...models.unet_1d import UNet1DModel from ...models.unet_1d import UNet1DModel
from ...pipelines import DiffusionPipeline from ...pipelines import DiffusionPipeline
from ...utils import randn_tensor
from ...utils.dummy_pt_objects import DDPMScheduler from ...utils.dummy_pt_objects import DDPMScheduler
@@ -127,7 +128,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
# generate initial noise and apply our conditions (to make the trajectories start at current state) # generate initial noise and apply our conditions (to make the trajectories start at current state)
x1 = torch.randn(shape, device=self.unet.device) x1 = randn_tensor(shape, device=self.unet.device)
x = self.reset_x0(x1, conditions, self.action_dim) x = self.reset_x0(x1, conditions, self.action_dim)
x = self.to_torch(x) x = self.to_torch(x)
+1 -1
View File
@@ -95,7 +95,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim) self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
causal_attention_mask = torch.full( causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf") [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
) )
causal_attention_mask.triu_(1) causal_attention_mask.triu_(1)
causal_attention_mask = causal_attention_mask[None, ...] causal_attention_mask = causal_attention_mask[None, ...]
+4 -5
View File
@@ -18,7 +18,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..utils import BaseOutput from ..utils import BaseOutput, randn_tensor
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -323,11 +323,10 @@ class DiagonalGaussianDistribution(object):
) )
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
# make sure sample is on the same device as the parameters and has same dtype # make sure sample is on the same device as the parameters and has same dtype
sample = sample.to(device=device, dtype=self.parameters.dtype) sample = randn_tensor(
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
)
x = self.mean + self.std * sample x = self.mean + self.std * sample
return x return x
@@ -31,7 +31,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import deprecate, logging, replace_example_docstring from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
@@ -401,20 +401,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if device.type == "mps" else device latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device) latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
@@ -33,7 +33,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
@@ -461,16 +461,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
else: else:
init_latents = torch.cat([init_latents], dim=0) init_latents = torch.cat([init_latents], dim=0)
rand_device = "cpu" if device.type == "mps" else device
shape = init_latents.shape shape = init_latents.shape
if isinstance(generator, list): noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
shape = (1,) + shape[1:]
noise = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
]
noise = torch.cat(noise, dim=0).to(device)
else:
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
# get latents # get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep) init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -23,6 +23,7 @@ from PIL import Image
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, DDPMScheduler from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
from .mel import Mel from .mel import Mel
@@ -126,7 +127,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
input_dims = self.get_input_dims() input_dims = self.get_input_dims()
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0]) self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
if noise is None: if noise is None:
noise = torch.randn( noise = randn_tensor(
( (
batch_size, batch_size,
self.unet.in_channels, self.unet.in_channels,
@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...utils import logging from ...utils import logging, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -100,16 +100,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
f" size of {batch_size}. Make sure the batch size matches the length of the generators." f" size of {batch_size}. Make sure the batch size matches the length of the generators."
) )
rand_device = "cpu" if self.device.type == "mps" else self.device audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
audio = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
for i in range(batch_size)
]
audio = torch.cat(audio, dim=0).to(self.device)
else:
audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device)
# set step values # set step values
self.scheduler.set_timesteps(num_inference_steps, device=audio.device) self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
+2 -12
View File
@@ -16,7 +16,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...utils import deprecate from ...utils import deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -103,17 +103,7 @@ class DDIMPipeline(DiffusionPipeline):
f" size of {batch_size}. Make sure the batch size matches the length of the generators." f" size of {batch_size}. Make sure the batch size matches the length of the generators."
) )
rand_device = "cpu" if self.device.type == "mps" else self.device image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
if isinstance(generator, list):
shape = (1,) + image_shape[1:]
image = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
for i in range(batch_size)
]
image = torch.cat(image, dim=0).to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
image = image.to(self.device)
# set step values # set step values
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
@@ -18,7 +18,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...utils import deprecate from ...utils import deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -100,10 +100,10 @@ class DDPMPipeline(DiffusionPipeline):
if self.device.type == "mps": if self.device.type == "mps":
# randn does not work reproducibly on mps # randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator) image = randn_tensor(image_shape, generator=generator)
image = image.to(self.device) image = image.to(self.device)
else: else:
image = torch.randn(image_shape, generator=generator, device=self.device) image = randn_tensor(image_shape, generator=generator, device=self.device)
# set step values # set step values
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
@@ -26,6 +26,7 @@ from transformers.utils import logging
from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -143,20 +144,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if self.device.type == "mps" else self.device latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
if isinstance(generator, list):
latents_shape = (1,) + latents_shape[1:]
latents = [
torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0)
else:
latents = torch.randn(
latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
)
latents = latents.to(self.device)
else: else:
if latents.shape != latents_shape: if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -16,7 +16,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate from ...utils import PIL_INTERPOLATION, deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -121,12 +121,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
latents_shape = (batch_size, self.unet.in_channels // 2, height, width) latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
latents_dtype = next(self.unet.parameters()).dtype latents_dtype = next(self.unet.parameters()).dtype
if self.device.type == "mps": latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype)
latents = latents.to(self.device)
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
image = image.to(device=self.device, dtype=latents_dtype) image = image.to(device=self.device, dtype=latents_dtype)
@@ -19,6 +19,7 @@ import torch
from ...models import UNet2DModel, VQModel from ...models import UNet2DModel, VQModel
from ...schedulers import DDIMScheduler from ...schedulers import DDIMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -71,7 +72,7 @@ class LDMPipeline(DiffusionPipeline):
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
""" """
latents = torch.randn( latents = randn_tensor(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator, generator=generator,
) )
@@ -34,7 +34,7 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
self.proj_out = nn.Linear(config.hidden_size, self.proj_size) self.proj_out = nn.Linear(config.hidden_size, self.proj_size)
# uncondition for scaling # uncondition for scaling
self.uncond_vector = nn.Parameter(torch.rand((1, 1, self.proj_size))) self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
def forward(self, pixel_values): def forward(self, pixel_values):
clip_output = self.model(pixel_values=pixel_values) clip_output = self.model(pixel_values=pixel_values)
@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging from ...utils import logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -300,20 +300,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if device.type == "mps" else device latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device) latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
@@ -19,6 +19,7 @@ import torch
from ...models import UNet2DModel from ...models import UNet2DModel
from ...schedulers import PNDMScheduler from ...schedulers import PNDMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -72,11 +73,11 @@ class PNDMPipeline(DiffusionPipeline):
# the official paper: https://arxiv.org/pdf/2202.09778.pdf # the official paper: https://arxiv.org/pdf/2202.09778.pdf
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
image = torch.randn( image = randn_tensor(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator, generator=generator,
device=self.device,
) )
image = image.to(self.device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps): for t in self.progress_bar(self.scheduler.timesteps):
@@ -22,7 +22,7 @@ import PIL
from ...models import UNet2DModel from ...models import UNet2DModel
from ...schedulers import RePaintScheduler from ...schedulers import RePaintScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -143,18 +143,8 @@ class RePaintPipeline(DiffusionPipeline):
f" size of {batch_size}. Make sure the batch size matches the length of the generators." f" size of {batch_size}. Make sure the batch size matches the length of the generators."
) )
rand_device = "cpu" if self.device.type == "mps" else self.device
image_shape = original_image.shape image_shape = original_image.shape
if isinstance(generator, list): image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
shape = (1,) + image_shape[1:]
image = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
for i in range(batch_size)
]
image = torch.cat(image, dim=0).to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
image = image.to(self.device)
# set step values # set step values
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device) self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
@@ -18,6 +18,7 @@ import torch
from ...models import UNet2DModel from ...models import UNet2DModel
from ...schedulers import ScoreSdeVeScheduler from ...schedulers import ScoreSdeVeScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -69,7 +70,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
model = self.unet model = self.unet
sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
sample = sample.to(self.device) sample = sample.to(self.device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
@@ -26,7 +26,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
@@ -76,7 +76,7 @@ def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta
# direction pointing to x_t # direction pointing to x_t
e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5) e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
noise = std_dev_t * torch.randn( noise = std_dev_t * randn_tensor(
clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator
) )
prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
@@ -472,16 +472,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
# add noise to latents using the timestep # add noise to latents using the timestep
rand_device = "cpu" if device.type == "mps" else device
shape = init_latents.shape shape = init_latents.shape
if isinstance(generator, list): noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
shape = (1,) + shape[1:]
noise = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
]
noise = torch.cat(noise, dim=0).to(device)
else:
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
# get latents # get latents
clean_latents = init_latents clean_latents = init_latents
@@ -30,7 +30,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import deprecate, is_accelerate_available, logging, replace_example_docstring from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
@@ -398,20 +398,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if device.type == "mps" else device latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device) latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
@@ -20,7 +20,6 @@ import numpy as np
import torch import torch
import PIL import PIL
from diffusers.utils import is_accelerate_available
from packaging import version from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
@@ -34,7 +33,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -381,16 +380,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
else: else:
init_latents = torch.cat([init_latents], dim=0) init_latents = torch.cat([init_latents], dim=0)
rand_device = "cpu" if device.type == "mps" else device
shape = init_latents.shape shape = init_latents.shape
if isinstance(generator, list): noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
shape = (1,) + shape[1:]
noise = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
]
noise = torch.cat(noise, dim=0).to(device)
else:
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
# get latents # get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep) init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import torch import torch
import PIL import PIL
from diffusers.utils import is_accelerate_available
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
@@ -32,7 +31,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import deprecate, logging from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
@@ -267,20 +266,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if device.type == "mps" else device latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device) latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
@@ -32,7 +32,14 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, replace_example_docstring from ...utils import (
PIL_INTERPOLATION,
deprecate,
is_accelerate_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
@@ -464,16 +471,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else: else:
init_latents = torch.cat([init_latents], dim=0) init_latents = torch.cat([init_latents], dim=0)
rand_device = "cpu" if device.type == "mps" else device
shape = init_latents.shape shape = init_latents.shape
if isinstance(generator, list): noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
shape = (1,) + shape[1:]
noise = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
]
noise = torch.cat(noise, dim=0).to(device)
else:
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
# get latents # get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep) init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -19,14 +19,13 @@ import numpy as np
import torch import torch
import PIL import PIL
from diffusers.utils import is_accelerate_available
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
@@ -470,20 +469,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if device.type == "mps" else device latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device) latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
@@ -19,7 +19,6 @@ import numpy as np
import torch import torch
import PIL import PIL
from diffusers.utils import is_accelerate_available
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -33,7 +32,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
@@ -414,7 +413,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
init_latents_orig = init_latents init_latents_orig = init_latents
# add noise to latents using the timesteps # add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=dtype) noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timestep) init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents latents = init_latents
return latents, init_latents_orig, noise return latents, init_latents_orig, noise
@@ -21,7 +21,7 @@ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from ...pipelines import DiffusionPipeline from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging, randn_tensor
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
@@ -308,11 +308,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None: if latents is None:
if device.type == "mps": latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else: else:
if latents.shape != shape: if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
@@ -19,12 +19,11 @@ import numpy as np
import torch import torch
import PIL import PIL
from diffusers.utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -313,11 +312,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height, width) shape = (batch_size, num_channels_latents, height, width)
if latents is None: if latents is None:
if device.type == "mps": latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else: else:
if latents.shape != shape: if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
@@ -450,11 +445,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
# 5. Add noise to image # 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
if device.type == "mps": noise = randn_tensor(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
# randn does not work reproducibly on mps
noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device)
else:
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level) image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1 batch_multiplier = 2 if do_classifier_free_guidance else 1
@@ -18,7 +18,7 @@ from ...schedulers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from ...utils import deprecate, is_accelerate_available, logging from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionSafePipelineOutput from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker from .safety_checker import SafeStableDiffusionSafetyChecker
@@ -429,20 +429,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if device.type == "mps" else device latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device) latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
@@ -18,6 +18,7 @@ import torch
from ...models import UNet2DModel from ...models import UNet2DModel
from ...schedulers import KarrasVeScheduler from ...schedulers import KarrasVeScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -81,8 +82,7 @@ class KarrasVePipeline(DiffusionPipeline):
model = self.unet model = self.unet
# sample x_0 ~ N(0, sigma_0^2 * I) # sample x_0 ~ N(0, sigma_0^2 * I)
sample = torch.randn(*shape) * self.scheduler.init_noise_sigma sample = randn_tensor(shape, device=self.device) * self.scheduler.init_noise_sigma
sample = sample.to(self.device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
@@ -24,7 +24,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging, torch_randn from ...utils import is_accelerate_available, logging, randn_tensor
from .text_proj import UnCLIPTextProjModel from .text_proj import UnCLIPTextProjModel
@@ -105,7 +105,7 @@ class UnCLIPPipeline(DiffusionPipeline):
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None: if latents is None:
latents = torch_randn(shape, generator=generator, device=device, dtype=dtype) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else: else:
if latents.shape != shape: if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
@@ -499,7 +499,6 @@ class UnCLIPPipeline(DiffusionPipeline):
).prev_sample ).prev_sample
image = super_res_latents image = super_res_latents
# done super res # done super res
# post processing # post processing
@@ -29,7 +29,7 @@ from transformers import (
from ...models import UNet2DConditionModel, UNet2DModel from ...models import UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging, torch_randn from ...utils import is_accelerate_available, logging, randn_tensor
from .text_proj import UnCLIPTextProjModel from .text_proj import UnCLIPTextProjModel
@@ -113,7 +113,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None: if latents is None:
latents = torch_randn(shape, generator=generator, device=device, dtype=dtype) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else: else:
if latents.shape != shape: if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
@@ -29,7 +29,7 @@ from transformers import (
from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel from .modeling_text_unet import UNetFlatConditionModel
@@ -382,20 +382,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if device.type == "mps" else device latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device) latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -248,20 +248,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if device.type == "mps" else device latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device) latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
@@ -22,7 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIP
from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel from .modeling_text_unet import UNetFlatConditionModel
@@ -298,20 +298,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
) )
if latents is None: if latents is None:
rand_device = "cpu" if device.type == "mps" else device latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device) latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler # scale the initial noise by the standard deviation required by the scheduler
+2 -7
View File
@@ -23,7 +23,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@@ -324,12 +324,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
) )
if variance_noise is None: if variance_noise is None:
if device.type == "mps": variance_noise = randn_tensor(
# randn does not work reproducibly on mps
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
variance_noise = variance_noise.to(device)
else:
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype model_output.shape, generator=generator, device=device, dtype=model_output.dtype
) )
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
+2 -7
View File
@@ -22,7 +22,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@@ -313,12 +313,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance = 0 variance = 0
if t > 0: if t > 0:
device = model_output.device device = model_output.device
if device.type == "mps": variance_noise = randn_tensor(
# randn does not work reproducibly on mps
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
variance_noise = variance_noise.to(device)
else:
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype model_output.shape, generator=generator, device=device, dtype=model_output.dtype
) )
if self.variance_type == "fixed_small_log": if self.variance_type == "fixed_small_log":
@@ -19,7 +19,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@@ -230,15 +230,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample = sample + derivative * dt prev_sample = sample + derivative * dt
device = model_output.device device = model_output.device
if device.type == "mps": noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
prev_sample = prev_sample + noise * sigma_up prev_sample = prev_sample + noise * sigma_up
@@ -19,7 +19,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@@ -217,15 +217,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
device = model_output.device noise = randn_tensor(
if device.type == "mps": model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
) )
eps = noise * s_noise eps = noise * s_noise
@@ -18,7 +18,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor
from .scheduling_utils import SchedulerMixin, SchedulerOutput from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -243,15 +243,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
device = model_output.device device = model_output.device
if device.type == "mps": noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
@@ -20,7 +20,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@@ -147,7 +147,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
gamma = 0 gamma = 0
# sample eps ~ N(0, S_noise^2 * I) # sample eps ~ N(0, S_noise^2 * I)
eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
sigma_hat = sigma + gamma * sigma sigma_hat = sigma + gamma * sigma
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
@@ -20,7 +20,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@@ -271,12 +271,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
# 5. Add noise # 5. Add noise
device = model_output.device device = model_output.device
if device.type == "mps": noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
noise = noise.to(device)
else:
noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5 std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
variance = 0 variance = 0
@@ -311,10 +306,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
beta = self.betas[timestep + i] beta = self.betas[timestep + i]
if sample.device.type == "mps": if sample.device.type == "mps":
# randn does not work reproducibly on mps # randn does not work reproducibly on mps
noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator) noise = randn_tensor(sample.shape, dtype=sample.dtype, generator=generator)
noise = noise.to(sample.device) noise = noise.to(sample.device)
else: else:
noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
@@ -21,7 +21,7 @@ from typing import Optional, Tuple, Union
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import SchedulerMixin, SchedulerOutput from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -201,7 +201,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
drift = drift - diffusion**2 * model_output drift = drift - diffusion**2 * model_output
# equation 6: sample noise for the diffusion term of # equation 6: sample noise for the diffusion term of
noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device) noise = randn_tensor(
sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype
)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise? # TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
@@ -241,7 +243,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction # sample noise for correction
noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device) noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
# compute step size from the model_output, the noise, and the snr # compute step size from the model_output, the noise, and the snr
grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean() grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
@@ -20,6 +20,7 @@ from typing import Union
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@@ -80,7 +81,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
x_mean = x + drift * dt x_mean = x + drift * dt
# add noise # add noise
noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device) noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype)
x = x_mean + diffusion * math.sqrt(-dt) * noise x = x_mean + diffusion * math.sqrt(-dt) * noise
return x, x_mean return x, x_mean
@@ -20,7 +20,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, torch_randn from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@@ -273,7 +273,7 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
# 6. Add noise # 6. Add noise
variance = 0 variance = 0
if t > 0: if t > 0:
variance_noise = torch_randn( variance_noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
) )
+1 -1
View File
@@ -64,7 +64,7 @@ from .import_utils import (
from .logging import get_logger from .logging import get_logger
from .outputs import BaseOutput from .outputs import BaseOutput
from .pil_utils import PIL_INTERPOLATION from .pil_utils import PIL_INTERPOLATION
from .torch_utils import torch_randn from .torch_utils import randn_tensor
if is_torch_available(): if is_torch_available():
+12 -6
View File
@@ -26,11 +26,12 @@ if is_torch_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def torch_randn( def randn_tensor(
shape: Union[Tuple, List], shape: Union[Tuple, List],
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
device: Optional["torch.device"] = None, device: Optional["torch.device"] = None,
dtype: Optional["torch.dtype"] = None, dtype: Optional["torch.dtype"] = None,
layout: Optional["torch.layout"] = None,
): ):
"""This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
@@ -40,8 +41,12 @@ def torch_randn(
rand_device = device rand_device = device
batch_size = shape[0] batch_size = shape[0]
layout = layout or torch.strided
device = device or torch.device("cpu")
if generator is not None: if generator is not None:
if generator.device != device and generator.device.type == "cpu": gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
if gen_device_type != device.type and gen_device_type == "cpu":
rand_device = "cpu" rand_device = "cpu"
if device != "mps": if device != "mps":
logger.info( logger.info(
@@ -49,16 +54,17 @@ def torch_randn(
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
f" slighly speed up this function by passing a generator that was created on the {device} device." f" slighly speed up this function by passing a generator that was created on the {device} device."
) )
elif generator.device.type != device.type and generator.device.type == "cuda": elif gen_device_type != device.type and gen_device_type == "cuda":
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {generator.device.type}.") raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
if isinstance(generator, list): if isinstance(generator, list):
shape = (1,) + shape[1:] shape = (1,) + shape[1:]
latents = [ latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size) torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
for i in range(batch_size)
] ]
latents = torch.cat(latents, dim=0).to(device) latents = torch.cat(latents, dim=0).to(device)
else: else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
return latents return latents
@@ -25,44 +25,6 @@ from diffusers.utils.testing_utils import require_torch, slow, torch_device
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
class KarrasVePipelineFastTests(unittest.TestCase):
@property
def dummy_uncond_unet(self):
torch.manual_seed(0)
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
return model
def test_inference(self):
unet = self.dummy_uncond_unet
scheduler = KarrasVeScheduler()
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@slow @slow
@require_torch @require_torch
class KarrasVePipelineIntegrationTests(unittest.TestCase): class KarrasVePipelineIntegrationTests(unittest.TestCase):
@@ -132,7 +132,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3) assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4397, 0.5553, 0.3802, 0.5222, 0.5811, 0.4342, 0.494, 0.4577, 0.4428]) expected_slice = np.array([0.4701, 0.5555, 0.3994, 0.5107, 0.5691, 0.4517, 0.5125, 0.4769, 0.4539])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+39 -4
View File
@@ -21,7 +21,7 @@ import torch
from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.utils import load_numpy, slow, torch_device from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
@@ -363,6 +363,37 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
assert np.abs(image - image_from_text).max() < 1e-4 assert np.abs(image - image_from_text).max() < 1e-4
@nightly
class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_unclip_karlo_cpu_fp32(self):
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/unclip/karlo_v1_alpha_horse_cpu.npy"
)
pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
pipeline.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
output = pipeline(
"horse",
num_images_per_prompt=1,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 1e-1
@slow @slow
@require_torch_gpu @require_torch_gpu
class UnCLIPPipelineIntegrationTests(unittest.TestCase): class UnCLIPPipelineIntegrationTests(unittest.TestCase):
@@ -385,15 +416,19 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
generator = torch.Generator(device="cpu").manual_seed(0) generator = torch.Generator(device="cpu").manual_seed(0)
output = pipeline( output = pipeline(
"horse", "horse",
num_images_per_prompt=1,
generator=generator, generator=generator,
output_type="np", output_type="np",
) )
image = output.images[0] image = np.asarray(pipeline.numpy_to_pil(output.images)[0], dtype=np.float32)
expected_image = np.asarray(pipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
# Karlo is extremely likely to strongly deviate depending on which hardware is used
# Here we just check that the image doesn't deviate more than 10 pixels from the reference image on average
avg_diff = np.abs(image - expected_image).mean()
assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
assert image.shape == (256, 256, 3) assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 1e-2
def test_unclip_pipeline_with_sequential_cpu_offloading(self): def test_unclip_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -475,20 +475,25 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
"/unclip/karlo_v1_alpha_cat_variation_fp16.npy" "/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
) )
pipeline = UnCLIPImageVariationPipeline.from_pretrained("fusing/karlo-image-variations-diffusers") pipeline = UnCLIPImageVariationPipeline.from_pretrained(
"fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16
)
pipeline = pipeline.to(torch_device) pipeline = pipeline.to(torch_device)
pipeline.set_progress_bar_config(disable=None) pipeline.set_progress_bar_config(disable=None)
pipeline.enable_sequential_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0) generator = torch.Generator(device="cpu").manual_seed(0)
output = pipeline( output = pipeline(
input_image, input_image,
num_images_per_prompt=1,
generator=generator, generator=generator,
output_type="np", output_type="np",
) )
image = output.images[0] image = np.asarray(pipeline.numpy_to_pil(output.images)[0], dtype=np.float32)
expected_image = np.asarray(pipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
# Karlo is extremely likely to strongly deviate depending on which hardware is used
# Here we just check that the image doesn't deviate more than 10 pixels from the reference image on average
avg_diff = np.abs(image - expected_image).mean()
assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
assert image.shape == (256, 256, 3) assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 5e-2