Improve reproducibility 2/3 (#1906)
* [Repro] Correct reproducibility * up * up * uP * up * need better image * allow conversion from no state dict checkpoints * up * up * up * up * check tensors * check tensors * check tensors * check tensors * next try * up * up * better name * up * up * Apply suggestions from code review * correct more * up * replace all torch randn * fix * correct * correct * finish * fix more * up
This commit is contained in:
committed by
GitHub
parent
67e2f95cc4
commit
9b63854886
@@ -19,6 +19,7 @@ import tqdm
|
|||||||
|
|
||||||
from ...models.unet_1d import UNet1DModel
|
from ...models.unet_1d import UNet1DModel
|
||||||
from ...pipelines import DiffusionPipeline
|
from ...pipelines import DiffusionPipeline
|
||||||
|
from ...utils import randn_tensor
|
||||||
from ...utils.dummy_pt_objects import DDPMScheduler
|
from ...utils.dummy_pt_objects import DDPMScheduler
|
||||||
|
|
||||||
|
|
||||||
@@ -127,7 +128,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
|
|||||||
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
|
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
|
||||||
|
|
||||||
# generate initial noise and apply our conditions (to make the trajectories start at current state)
|
# generate initial noise and apply our conditions (to make the trajectories start at current state)
|
||||||
x1 = torch.randn(shape, device=self.unet.device)
|
x1 = randn_tensor(shape, device=self.unet.device)
|
||||||
x = self.reset_x0(x1, conditions, self.action_dim)
|
x = self.reset_x0(x1, conditions, self.action_dim)
|
||||||
x = self.to_torch(x)
|
x = self.to_torch(x)
|
||||||
|
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
|||||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
|
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
|
||||||
|
|
||||||
causal_attention_mask = torch.full(
|
causal_attention_mask = torch.full(
|
||||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
|
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
||||||
)
|
)
|
||||||
causal_attention_mask.triu_(1)
|
causal_attention_mask.triu_(1)
|
||||||
causal_attention_mask = causal_attention_mask[None, ...]
|
causal_attention_mask = causal_attention_mask[None, ...]
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ..utils import BaseOutput
|
from ..utils import BaseOutput, randn_tensor
|
||||||
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||||
|
|
||||||
|
|
||||||
@@ -323,11 +323,10 @@ class DiagonalGaussianDistribution(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
||||||
device = self.parameters.device
|
|
||||||
sample_device = "cpu" if device.type == "mps" else device
|
|
||||||
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
|
|
||||||
# make sure sample is on the same device as the parameters and has same dtype
|
# make sure sample is on the same device as the parameters and has same dtype
|
||||||
sample = sample.to(device=device, dtype=self.parameters.dtype)
|
sample = randn_tensor(
|
||||||
|
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
|
||||||
|
)
|
||||||
x = self.mean + self.std * sample
|
x = self.mean + self.std * sample
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from ...schedulers import (
|
|||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import deprecate, logging, replace_example_docstring
|
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
|
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
|
||||||
@@ -401,20 +401,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from ...schedulers import (
|
|||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
|
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
|
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
|
||||||
@@ -461,16 +461,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||||||
else:
|
else:
|
||||||
init_latents = torch.cat([init_latents], dim=0)
|
init_latents = torch.cat([init_latents], dim=0)
|
||||||
|
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
|
||||||
shape = init_latents.shape
|
shape = init_latents.shape
|
||||||
if isinstance(generator, list):
|
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
noise = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
|
||||||
]
|
|
||||||
noise = torch.cat(noise, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
|
|
||||||
# get latents
|
# get latents
|
||||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler, DDPMScheduler
|
from ...schedulers import DDIMScheduler, DDPMScheduler
|
||||||
|
from ...utils import randn_tensor
|
||||||
from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
|
||||||
from .mel import Mel
|
from .mel import Mel
|
||||||
|
|
||||||
@@ -126,7 +127,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
|||||||
input_dims = self.get_input_dims()
|
input_dims = self.get_input_dims()
|
||||||
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
|
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
|
||||||
if noise is None:
|
if noise is None:
|
||||||
noise = torch.randn(
|
noise = randn_tensor(
|
||||||
(
|
(
|
||||||
batch_size,
|
batch_size,
|
||||||
self.unet.in_channels,
|
self.unet.in_channels,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...utils import logging
|
from ...utils import logging, randn_tensor
|
||||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
@@ -100,16 +100,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
|||||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
)
|
)
|
||||||
|
|
||||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype)
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
audio = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
audio = torch.cat(audio, dim=0).to(self.device)
|
|
||||||
else:
|
|
||||||
audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device)
|
|
||||||
|
|
||||||
# set step values
|
# set step values
|
||||||
self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
|
self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...utils import deprecate
|
from ...utils import deprecate, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -103,17 +103,7 @@ class DDIMPipeline(DiffusionPipeline):
|
|||||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
)
|
)
|
||||||
|
|
||||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + image_shape[1:]
|
|
||||||
image = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
image = torch.cat(image, dim=0).to(self.device)
|
|
||||||
else:
|
|
||||||
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
|
|
||||||
image = image.to(self.device)
|
|
||||||
|
|
||||||
# set step values
|
# set step values
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
from ...utils import deprecate
|
from ...utils import deprecate, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -100,10 +100,10 @@ class DDPMPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
if self.device.type == "mps":
|
if self.device.type == "mps":
|
||||||
# randn does not work reproducibly on mps
|
# randn does not work reproducibly on mps
|
||||||
image = torch.randn(image_shape, generator=generator)
|
image = randn_tensor(image_shape, generator=generator)
|
||||||
image = image.to(self.device)
|
image = image.to(self.device)
|
||||||
else:
|
else:
|
||||||
image = torch.randn(image_shape, generator=generator, device=self.device)
|
image = randn_tensor(image_shape, generator=generator, device=self.device)
|
||||||
|
|
||||||
# set step values
|
# set step values
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from transformers.utils import logging
|
|||||||
|
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
|
from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
|
||||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
|
from ...utils import randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -143,20 +144,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
latents_shape = (1,) + latents_shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(
|
|
||||||
latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
|
|
||||||
)
|
|
||||||
latents = latents.to(self.device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != latents_shape:
|
if latents.shape != latents_shape:
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||||
|
|||||||
+2
-7
@@ -16,7 +16,7 @@ from ...schedulers import (
|
|||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import PIL_INTERPOLATION, deprecate
|
from ...utils import PIL_INTERPOLATION, deprecate, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -121,12 +121,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
|
|||||||
latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
|
latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
|
||||||
latents_dtype = next(self.unet.parameters()).dtype
|
latents_dtype = next(self.unet.parameters()).dtype
|
||||||
|
|
||||||
if self.device.type == "mps":
|
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||||
# randn does not work reproducibly on mps
|
|
||||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype)
|
|
||||||
latents = latents.to(self.device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
|
||||||
|
|
||||||
image = image.to(device=self.device, dtype=latents_dtype)
|
image = image.to(device=self.device, dtype=latents_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import torch
|
|||||||
|
|
||||||
from ...models import UNet2DModel, VQModel
|
from ...models import UNet2DModel, VQModel
|
||||||
from ...schedulers import DDIMScheduler
|
from ...schedulers import DDIMScheduler
|
||||||
|
from ...utils import randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +72,7 @@ class LDMPipeline(DiffusionPipeline):
|
|||||||
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
|
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
latents = torch.randn(
|
latents = randn_tensor(
|
||||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||||
generator=generator,
|
generator=generator,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
|
|||||||
self.proj_out = nn.Linear(config.hidden_size, self.proj_size)
|
self.proj_out = nn.Linear(config.hidden_size, self.proj_size)
|
||||||
|
|
||||||
# uncondition for scaling
|
# uncondition for scaling
|
||||||
self.uncond_vector = nn.Parameter(torch.rand((1, 1, self.proj_size)))
|
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
clip_output = self.model(pixel_values=pixel_values)
|
clip_output = self.model(pixel_values=pixel_values)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor
|
|||||||
|
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from ...utils import logging
|
from ...utils import logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
@@ -300,20 +300,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import torch
|
|||||||
|
|
||||||
from ...models import UNet2DModel
|
from ...models import UNet2DModel
|
||||||
from ...schedulers import PNDMScheduler
|
from ...schedulers import PNDMScheduler
|
||||||
|
from ...utils import randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -72,11 +73,11 @@ class PNDMPipeline(DiffusionPipeline):
|
|||||||
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
|
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||||
|
|
||||||
# Sample gaussian noise to begin loop
|
# Sample gaussian noise to begin loop
|
||||||
image = torch.randn(
|
image = randn_tensor(
|
||||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
image = image.to(self.device)
|
|
||||||
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
for t in self.progress_bar(self.scheduler.timesteps):
|
for t in self.progress_bar(self.scheduler.timesteps):
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import PIL
|
|||||||
|
|
||||||
from ...models import UNet2DModel
|
from ...models import UNet2DModel
|
||||||
from ...schedulers import RePaintScheduler
|
from ...schedulers import RePaintScheduler
|
||||||
from ...utils import PIL_INTERPOLATION, deprecate, logging
|
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -143,18 +143,8 @@ class RePaintPipeline(DiffusionPipeline):
|
|||||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
)
|
)
|
||||||
|
|
||||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
|
||||||
image_shape = original_image.shape
|
image_shape = original_image.shape
|
||||||
if isinstance(generator, list):
|
image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
|
||||||
shape = (1,) + image_shape[1:]
|
|
||||||
image = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
image = torch.cat(image, dim=0).to(self.device)
|
|
||||||
else:
|
|
||||||
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
|
|
||||||
image = image.to(self.device)
|
|
||||||
|
|
||||||
# set step values
|
# set step values
|
||||||
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
|
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import torch
|
|||||||
|
|
||||||
from ...models import UNet2DModel
|
from ...models import UNet2DModel
|
||||||
from ...schedulers import ScoreSdeVeScheduler
|
from ...schedulers import ScoreSdeVeScheduler
|
||||||
|
from ...utils import randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -69,7 +70,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
model = self.unet
|
model = self.unet
|
||||||
|
|
||||||
sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
|
sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
|
||||||
sample = sample.to(self.device)
|
sample = sample.to(self.device)
|
||||||
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler
|
from ...schedulers import DDIMScheduler
|
||||||
from ...utils import PIL_INTERPOLATION, deprecate, logging
|
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from . import StableDiffusionPipelineOutput
|
from . import StableDiffusionPipelineOutput
|
||||||
from .safety_checker import StableDiffusionSafetyChecker
|
from .safety_checker import StableDiffusionSafetyChecker
|
||||||
@@ -76,7 +76,7 @@ def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta
|
|||||||
# direction pointing to x_t
|
# direction pointing to x_t
|
||||||
e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
|
e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
|
||||||
dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
|
dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
|
||||||
noise = std_dev_t * torch.randn(
|
noise = std_dev_t * randn_tensor(
|
||||||
clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator
|
clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator
|
||||||
)
|
)
|
||||||
prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
|
prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
|
||||||
@@ -472,16 +472,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
|||||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||||
|
|
||||||
# add noise to latents using the timestep
|
# add noise to latents using the timestep
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
|
||||||
shape = init_latents.shape
|
shape = init_latents.shape
|
||||||
if isinstance(generator, list):
|
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
noise = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
|
||||||
]
|
|
||||||
noise = torch.cat(noise, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
|
|
||||||
# get latents
|
# get latents
|
||||||
clean_latents = init_latents
|
clean_latents = init_latents
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from ...schedulers import (
|
|||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import deprecate, is_accelerate_available, logging, replace_example_docstring
|
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from . import StableDiffusionPipelineOutput
|
from . import StableDiffusionPipelineOutput
|
||||||
from .safety_checker import StableDiffusionSafetyChecker
|
from .safety_checker import StableDiffusionSafetyChecker
|
||||||
@@ -398,20 +398,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from diffusers.utils import is_accelerate_available
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
|
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
|
||||||
|
|
||||||
@@ -34,7 +33,7 @@ from ...schedulers import (
|
|||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import PIL_INTERPOLATION, deprecate, logging
|
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -381,16 +380,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
|||||||
else:
|
else:
|
||||||
init_latents = torch.cat([init_latents], dim=0)
|
init_latents = torch.cat([init_latents], dim=0)
|
||||||
|
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
|
||||||
shape = init_latents.shape
|
shape = init_latents.shape
|
||||||
if isinstance(generator, list):
|
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
noise = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
|
||||||
]
|
|
||||||
noise = torch.cat(noise, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
|
|
||||||
# get latents
|
# get latents
|
||||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||||
|
|||||||
+2
-15
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from diffusers.utils import is_accelerate_available
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
|
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
@@ -32,7 +31,7 @@ from ...schedulers import (
|
|||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import deprecate, logging
|
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from . import StableDiffusionPipelineOutput
|
from . import StableDiffusionPipelineOutput
|
||||||
from .safety_checker import StableDiffusionSafetyChecker
|
from .safety_checker import StableDiffusionSafetyChecker
|
||||||
@@ -267,20 +266,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
|||||||
@@ -32,7 +32,14 @@ from ...schedulers import (
|
|||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, replace_example_docstring
|
from ...utils import (
|
||||||
|
PIL_INTERPOLATION,
|
||||||
|
deprecate,
|
||||||
|
is_accelerate_available,
|
||||||
|
logging,
|
||||||
|
randn_tensor,
|
||||||
|
replace_example_docstring,
|
||||||
|
)
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from . import StableDiffusionPipelineOutput
|
from . import StableDiffusionPipelineOutput
|
||||||
from .safety_checker import StableDiffusionSafetyChecker
|
from .safety_checker import StableDiffusionSafetyChecker
|
||||||
@@ -464,16 +471,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||||||
else:
|
else:
|
||||||
init_latents = torch.cat([init_latents], dim=0)
|
init_latents = torch.cat([init_latents], dim=0)
|
||||||
|
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
|
||||||
shape = init_latents.shape
|
shape = init_latents.shape
|
||||||
if isinstance(generator, list):
|
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
noise = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
|
||||||
]
|
|
||||||
noise = torch.cat(noise, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
|
|
||||||
# get latents
|
# get latents
|
||||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||||
|
|||||||
@@ -19,14 +19,13 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from diffusers.utils import is_accelerate_available
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from ...utils import deprecate, logging
|
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from . import StableDiffusionPipelineOutput
|
from . import StableDiffusionPipelineOutput
|
||||||
from .safety_checker import StableDiffusionSafetyChecker
|
from .safety_checker import StableDiffusionSafetyChecker
|
||||||
@@ -470,20 +469,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
|||||||
+2
-3
@@ -19,7 +19,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from diffusers.utils import is_accelerate_available
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
@@ -33,7 +32,7 @@ from ...schedulers import (
|
|||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import PIL_INTERPOLATION, deprecate, logging
|
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from . import StableDiffusionPipelineOutput
|
from . import StableDiffusionPipelineOutput
|
||||||
from .safety_checker import StableDiffusionSafetyChecker
|
from .safety_checker import StableDiffusionSafetyChecker
|
||||||
@@ -414,7 +413,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||||||
init_latents_orig = init_latents
|
init_latents_orig = init_latents
|
||||||
|
|
||||||
# add noise to latents using the timesteps
|
# add noise to latents using the timesteps
|
||||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
|
noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
|
||||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||||
latents = init_latents
|
latents = init_latents
|
||||||
return latents, init_latents_orig, noise
|
return latents, init_latents_orig, noise
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
|
|||||||
|
|
||||||
from ...pipelines import DiffusionPipeline
|
from ...pipelines import DiffusionPipeline
|
||||||
from ...schedulers import LMSDiscreteScheduler
|
from ...schedulers import LMSDiscreteScheduler
|
||||||
from ...utils import is_accelerate_available, logging
|
from ...utils import is_accelerate_available, logging, randn_tensor
|
||||||
from . import StableDiffusionPipelineOutput
|
from . import StableDiffusionPipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -308,11 +308,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
|||||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||||
if latents is None:
|
if latents is None:
|
||||||
if device.type == "mps":
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
# randn does not work reproducibly on mps
|
|
||||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
if latents.shape != shape:
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||||
|
|||||||
@@ -19,12 +19,11 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from diffusers.utils import is_accelerate_available
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from ...utils import logging
|
from ...utils import is_accelerate_available, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -313,11 +312,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
|||||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||||
shape = (batch_size, num_channels_latents, height, width)
|
shape = (batch_size, num_channels_latents, height, width)
|
||||||
if latents is None:
|
if latents is None:
|
||||||
if device.type == "mps":
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
# randn does not work reproducibly on mps
|
|
||||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
if latents.shape != shape:
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||||
@@ -450,11 +445,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
# 5. Add noise to image
|
# 5. Add noise to image
|
||||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||||
if device.type == "mps":
|
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
|
||||||
# randn does not work reproducibly on mps
|
|
||||||
noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device)
|
|
||||||
else:
|
|
||||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
|
|
||||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||||
|
|
||||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from ...schedulers import (
|
|||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from ...utils import deprecate, is_accelerate_available, logging
|
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
from . import StableDiffusionSafePipelineOutput
|
from . import StableDiffusionSafePipelineOutput
|
||||||
from .safety_checker import SafeStableDiffusionSafetyChecker
|
from .safety_checker import SafeStableDiffusionSafetyChecker
|
||||||
@@ -429,20 +429,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import torch
|
|||||||
|
|
||||||
from ...models import UNet2DModel
|
from ...models import UNet2DModel
|
||||||
from ...schedulers import KarrasVeScheduler
|
from ...schedulers import KarrasVeScheduler
|
||||||
|
from ...utils import randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -81,8 +82,7 @@ class KarrasVePipeline(DiffusionPipeline):
|
|||||||
model = self.unet
|
model = self.unet
|
||||||
|
|
||||||
# sample x_0 ~ N(0, sigma_0^2 * I)
|
# sample x_0 ~ N(0, sigma_0^2 * I)
|
||||||
sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
|
sample = randn_tensor(shape, device=self.device) * self.scheduler.init_noise_sigma
|
||||||
sample = sample.to(self.device)
|
|
||||||
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
|||||||
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
|
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
|
||||||
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
|
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
|
||||||
from ...schedulers import UnCLIPScheduler
|
from ...schedulers import UnCLIPScheduler
|
||||||
from ...utils import is_accelerate_available, logging, torch_randn
|
from ...utils import is_accelerate_available, logging, randn_tensor
|
||||||
from .text_proj import UnCLIPTextProjModel
|
from .text_proj import UnCLIPTextProjModel
|
||||||
|
|
||||||
|
|
||||||
@@ -105,7 +105,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||||
if latents is None:
|
if latents is None:
|
||||||
latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
if latents.shape != shape:
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||||
@@ -499,7 +499,6 @@ class UnCLIPPipeline(DiffusionPipeline):
|
|||||||
).prev_sample
|
).prev_sample
|
||||||
|
|
||||||
image = super_res_latents
|
image = super_res_latents
|
||||||
|
|
||||||
# done super res
|
# done super res
|
||||||
|
|
||||||
# post processing
|
# post processing
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from transformers import (
|
|||||||
from ...models import UNet2DConditionModel, UNet2DModel
|
from ...models import UNet2DConditionModel, UNet2DModel
|
||||||
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
|
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
|
||||||
from ...schedulers import UnCLIPScheduler
|
from ...schedulers import UnCLIPScheduler
|
||||||
from ...utils import is_accelerate_available, logging, torch_randn
|
from ...utils import is_accelerate_available, logging, randn_tensor
|
||||||
from .text_proj import UnCLIPTextProjModel
|
from .text_proj import UnCLIPTextProjModel
|
||||||
|
|
||||||
|
|
||||||
@@ -113,7 +113,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
|||||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||||
if latents is None:
|
if latents is None:
|
||||||
latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
if latents.shape != shape:
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||||
|
|||||||
+2
-14
@@ -29,7 +29,7 @@ from transformers import (
|
|||||||
|
|
||||||
from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
|
from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from ...utils import is_accelerate_available, logging
|
from ...utils import is_accelerate_available, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
from .modeling_text_unet import UNetFlatConditionModel
|
from .modeling_text_unet import UNetFlatConditionModel
|
||||||
|
|
||||||
@@ -382,20 +382,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
|||||||
+2
-14
@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
|
|||||||
|
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from ...utils import is_accelerate_available, logging
|
from ...utils import is_accelerate_available, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -248,20 +248,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
|||||||
+2
-14
@@ -22,7 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIP
|
|||||||
|
|
||||||
from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
|
from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
|
||||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from ...utils import is_accelerate_available, logging
|
from ...utils import is_accelerate_available, logging, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
from .modeling_text_unet import UNetFlatConditionModel
|
from .modeling_text_unet import UNetFlatConditionModel
|
||||||
|
|
||||||
@@ -298,20 +298,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if latents is None:
|
if latents is None:
|
||||||
rand_device = "cpu" if device.type == "mps" else device
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
if isinstance(generator, list):
|
|
||||||
shape = (1,) + shape[1:]
|
|
||||||
latents = [
|
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
|
||||||
for i in range(batch_size)
|
|
||||||
]
|
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
|
||||||
else:
|
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
||||||
else:
|
else:
|
||||||
if latents.shape != shape:
|
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
|
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -324,14 +324,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if variance_noise is None:
|
if variance_noise is None:
|
||||||
if device.type == "mps":
|
variance_noise = randn_tensor(
|
||||||
# randn does not work reproducibly on mps
|
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
||||||
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
|
)
|
||||||
variance_noise = variance_noise.to(device)
|
|
||||||
else:
|
|
||||||
variance_noise = torch.randn(
|
|
||||||
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
|
||||||
)
|
|
||||||
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
|
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
|
||||||
|
|
||||||
prev_sample = prev_sample + variance
|
prev_sample = prev_sample + variance
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
|
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -313,14 +313,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
variance = 0
|
variance = 0
|
||||||
if t > 0:
|
if t > 0:
|
||||||
device = model_output.device
|
device = model_output.device
|
||||||
if device.type == "mps":
|
variance_noise = randn_tensor(
|
||||||
# randn does not work reproducibly on mps
|
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
||||||
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
|
)
|
||||||
variance_noise = variance_noise.to(device)
|
|
||||||
else:
|
|
||||||
variance_noise = torch.randn(
|
|
||||||
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
|
||||||
)
|
|
||||||
if self.variance_type == "fixed_small_log":
|
if self.variance_type == "fixed_small_log":
|
||||||
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
|
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
|
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -230,15 +230,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
prev_sample = sample + derivative * dt
|
prev_sample = sample + derivative * dt
|
||||||
|
|
||||||
device = model_output.device
|
device = model_output.device
|
||||||
if device.type == "mps":
|
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
|
||||||
# randn does not work reproducibly on mps
|
|
||||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
|
|
||||||
prev_sample = prev_sample + noise * sigma_up
|
prev_sample = prev_sample + noise * sigma_up
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
|
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -217,16 +217,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||||
|
|
||||||
device = model_output.device
|
noise = randn_tensor(
|
||||||
if device.type == "mps":
|
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
|
||||||
# randn does not work reproducibly on mps
|
)
|
||||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
|
|
||||||
eps = noise * s_noise
|
eps = noise * s_noise
|
||||||
sigma_hat = sigma * (gamma + 1)
|
sigma_hat = sigma * (gamma + 1)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
|
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -243,15 +243,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
|
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
|
||||||
|
|
||||||
device = model_output.device
|
device = model_output.device
|
||||||
if device.type == "mps":
|
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
|
||||||
# randn does not work reproducibly on mps
|
|
||||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||||
if self.config.prediction_type == "epsilon":
|
if self.config.prediction_type == "epsilon":
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import BaseOutput
|
from ..utils import BaseOutput, randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -147,7 +147,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
gamma = 0
|
gamma = 0
|
||||||
|
|
||||||
# sample eps ~ N(0, S_noise^2 * I)
|
# sample eps ~ N(0, S_noise^2 * I)
|
||||||
eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
|
eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
|
||||||
sigma_hat = sigma + gamma * sigma
|
sigma_hat = sigma + gamma * sigma
|
||||||
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
|
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import BaseOutput
|
from ..utils import BaseOutput, randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -271,12 +271,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# 5. Add noise
|
# 5. Add noise
|
||||||
device = model_output.device
|
device = model_output.device
|
||||||
if device.type == "mps":
|
noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
|
||||||
# randn does not work reproducibly on mps
|
|
||||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
|
|
||||||
noise = noise.to(device)
|
|
||||||
else:
|
|
||||||
noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
|
|
||||||
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
|
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
|
||||||
|
|
||||||
variance = 0
|
variance = 0
|
||||||
@@ -311,10 +306,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
beta = self.betas[timestep + i]
|
beta = self.betas[timestep + i]
|
||||||
if sample.device.type == "mps":
|
if sample.device.type == "mps":
|
||||||
# randn does not work reproducibly on mps
|
# randn does not work reproducibly on mps
|
||||||
noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator)
|
noise = randn_tensor(sample.shape, dtype=sample.dtype, generator=generator)
|
||||||
noise = noise.to(sample.device)
|
noise = noise.to(sample.device)
|
||||||
else:
|
else:
|
||||||
noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
|
noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
|
||||||
|
|
||||||
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
|
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
|
||||||
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
|
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import BaseOutput
|
from ..utils import BaseOutput, randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -201,7 +201,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
drift = drift - diffusion**2 * model_output
|
drift = drift - diffusion**2 * model_output
|
||||||
|
|
||||||
# equation 6: sample noise for the diffusion term of
|
# equation 6: sample noise for the diffusion term of
|
||||||
noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
|
noise = randn_tensor(
|
||||||
|
sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype
|
||||||
|
)
|
||||||
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
|
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
|
||||||
# TODO is the variable diffusion the correct scaling term for the noise?
|
# TODO is the variable diffusion the correct scaling term for the noise?
|
||||||
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
|
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
|
||||||
@@ -241,7 +243,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
|
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
|
||||||
# sample noise for correction
|
# sample noise for correction
|
||||||
noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
|
noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
|
||||||
|
|
||||||
# compute step size from the model_output, the noise, and the snr
|
# compute step size from the model_output, the noise, and the snr
|
||||||
grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
|
grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from ..utils import randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -80,7 +81,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
x_mean = x + drift * dt
|
x_mean = x + drift * dt
|
||||||
|
|
||||||
# add noise
|
# add noise
|
||||||
noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device)
|
noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype)
|
||||||
x = x_mean + diffusion * math.sqrt(-dt) * noise
|
x = x_mean + diffusion * math.sqrt(-dt) * noise
|
||||||
|
|
||||||
return x, x_mean
|
return x, x_mean
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import BaseOutput, torch_randn
|
from ..utils import BaseOutput, randn_tensor
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -273,7 +273,7 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# 6. Add noise
|
# 6. Add noise
|
||||||
variance = 0
|
variance = 0
|
||||||
if t > 0:
|
if t > 0:
|
||||||
variance_noise = torch_randn(
|
variance_noise = randn_tensor(
|
||||||
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
|
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ from .import_utils import (
|
|||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
from .outputs import BaseOutput
|
from .outputs import BaseOutput
|
||||||
from .pil_utils import PIL_INTERPOLATION
|
from .pil_utils import PIL_INTERPOLATION
|
||||||
from .torch_utils import torch_randn
|
from .torch_utils import randn_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
|||||||
@@ -26,11 +26,12 @@ if is_torch_available():
|
|||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
def torch_randn(
|
def randn_tensor(
|
||||||
shape: Union[Tuple, List],
|
shape: Union[Tuple, List],
|
||||||
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
|
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
|
||||||
device: Optional["torch.device"] = None,
|
device: Optional["torch.device"] = None,
|
||||||
dtype: Optional["torch.dtype"] = None,
|
dtype: Optional["torch.dtype"] = None,
|
||||||
|
layout: Optional["torch.layout"] = None,
|
||||||
):
|
):
|
||||||
"""This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
|
"""This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
|
||||||
passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
|
passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
|
||||||
@@ -40,8 +41,12 @@ def torch_randn(
|
|||||||
rand_device = device
|
rand_device = device
|
||||||
batch_size = shape[0]
|
batch_size = shape[0]
|
||||||
|
|
||||||
|
layout = layout or torch.strided
|
||||||
|
device = device or torch.device("cpu")
|
||||||
|
|
||||||
if generator is not None:
|
if generator is not None:
|
||||||
if generator.device != device and generator.device.type == "cpu":
|
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
|
||||||
|
if gen_device_type != device.type and gen_device_type == "cpu":
|
||||||
rand_device = "cpu"
|
rand_device = "cpu"
|
||||||
if device != "mps":
|
if device != "mps":
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -49,16 +54,17 @@ def torch_randn(
|
|||||||
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
|
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
|
||||||
f" slighly speed up this function by passing a generator that was created on the {device} device."
|
f" slighly speed up this function by passing a generator that was created on the {device} device."
|
||||||
)
|
)
|
||||||
elif generator.device.type != device.type and generator.device.type == "cuda":
|
elif gen_device_type != device.type and gen_device_type == "cuda":
|
||||||
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {generator.device.type}.")
|
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
|
||||||
|
|
||||||
if isinstance(generator, list):
|
if isinstance(generator, list):
|
||||||
shape = (1,) + shape[1:]
|
shape = (1,) + shape[1:]
|
||||||
latents = [
|
latents = [
|
||||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
|
||||||
|
for i in range(batch_size)
|
||||||
]
|
]
|
||||||
latents = torch.cat(latents, dim=0).to(device)
|
latents = torch.cat(latents, dim=0).to(device)
|
||||||
else:
|
else:
|
||||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
|||||||
@@ -25,44 +25,6 @@ from diffusers.utils.testing_utils import require_torch, slow, torch_device
|
|||||||
torch.backends.cuda.matmul.allow_tf32 = False
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
|
||||||
|
|
||||||
class KarrasVePipelineFastTests(unittest.TestCase):
|
|
||||||
@property
|
|
||||||
def dummy_uncond_unet(self):
|
|
||||||
torch.manual_seed(0)
|
|
||||||
model = UNet2DModel(
|
|
||||||
block_out_channels=(32, 64),
|
|
||||||
layers_per_block=2,
|
|
||||||
sample_size=32,
|
|
||||||
in_channels=3,
|
|
||||||
out_channels=3,
|
|
||||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
|
||||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def test_inference(self):
|
|
||||||
unet = self.dummy_uncond_unet
|
|
||||||
scheduler = KarrasVeScheduler()
|
|
||||||
|
|
||||||
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
|
|
||||||
pipe.to(torch_device)
|
|
||||||
pipe.set_progress_bar_config(disable=None)
|
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
|
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
|
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
|
||||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
|
||||||
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.4397, 0.5553, 0.3802, 0.5222, 0.5811, 0.4342, 0.494, 0.4577, 0.4428])
|
expected_slice = np.array([0.4701, 0.5555, 0.3994, 0.5107, 0.5691, 0.4517, 0.5125, 0.4769, 0.4539])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import torch
|
|||||||
|
|
||||||
from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
|
from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
|
||||||
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
||||||
from diffusers.utils import load_numpy, slow, torch_device
|
from diffusers.utils import load_numpy, nightly, slow, torch_device
|
||||||
from diffusers.utils.testing_utils import require_torch_gpu
|
from diffusers.utils.testing_utils import require_torch_gpu
|
||||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
|
|
||||||
@@ -363,6 +363,37 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
|
|||||||
assert np.abs(image - image_from_text).max() < 1e-4
|
assert np.abs(image - image_from_text).max() < 1e-4
|
||||||
|
|
||||||
|
|
||||||
|
@nightly
|
||||||
|
class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
|
||||||
|
def tearDown(self):
|
||||||
|
# clean up the VRAM after each test
|
||||||
|
super().tearDown()
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def test_unclip_karlo_cpu_fp32(self):
|
||||||
|
expected_image = load_numpy(
|
||||||
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||||
|
"/unclip/karlo_v1_alpha_horse_cpu.npy"
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
|
||||||
|
pipeline.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
output = pipeline(
|
||||||
|
"horse",
|
||||||
|
num_images_per_prompt=1,
|
||||||
|
generator=generator,
|
||||||
|
output_type="np",
|
||||||
|
)
|
||||||
|
|
||||||
|
image = output.images[0]
|
||||||
|
|
||||||
|
assert image.shape == (256, 256, 3)
|
||||||
|
assert np.abs(expected_image - image).max() < 1e-1
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
class UnCLIPPipelineIntegrationTests(unittest.TestCase):
|
class UnCLIPPipelineIntegrationTests(unittest.TestCase):
|
||||||
@@ -385,15 +416,19 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
|
|||||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||||
output = pipeline(
|
output = pipeline(
|
||||||
"horse",
|
"horse",
|
||||||
num_images_per_prompt=1,
|
|
||||||
generator=generator,
|
generator=generator,
|
||||||
output_type="np",
|
output_type="np",
|
||||||
)
|
)
|
||||||
|
|
||||||
image = output.images[0]
|
image = np.asarray(pipeline.numpy_to_pil(output.images)[0], dtype=np.float32)
|
||||||
|
expected_image = np.asarray(pipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
|
||||||
|
|
||||||
|
# Karlo is extremely likely to strongly deviate depending on which hardware is used
|
||||||
|
# Here we just check that the image doesn't deviate more than 10 pixels from the reference image on average
|
||||||
|
avg_diff = np.abs(image - expected_image).mean()
|
||||||
|
|
||||||
|
assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
|
||||||
assert image.shape == (256, 256, 3)
|
assert image.shape == (256, 256, 3)
|
||||||
assert np.abs(expected_image - image).max() < 1e-2
|
|
||||||
|
|
||||||
def test_unclip_pipeline_with_sequential_cpu_offloading(self):
|
def test_unclip_pipeline_with_sequential_cpu_offloading(self):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|||||||
@@ -475,20 +475,25 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
|
|||||||
"/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
|
"/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = UnCLIPImageVariationPipeline.from_pretrained("fusing/karlo-image-variations-diffusers")
|
pipeline = UnCLIPImageVariationPipeline.from_pretrained(
|
||||||
|
"fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16
|
||||||
|
)
|
||||||
pipeline = pipeline.to(torch_device)
|
pipeline = pipeline.to(torch_device)
|
||||||
pipeline.set_progress_bar_config(disable=None)
|
pipeline.set_progress_bar_config(disable=None)
|
||||||
pipeline.enable_sequential_cpu_offload()
|
|
||||||
|
|
||||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||||
output = pipeline(
|
output = pipeline(
|
||||||
input_image,
|
input_image,
|
||||||
num_images_per_prompt=1,
|
|
||||||
generator=generator,
|
generator=generator,
|
||||||
output_type="np",
|
output_type="np",
|
||||||
)
|
)
|
||||||
|
|
||||||
image = output.images[0]
|
image = np.asarray(pipeline.numpy_to_pil(output.images)[0], dtype=np.float32)
|
||||||
|
expected_image = np.asarray(pipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
|
||||||
|
|
||||||
|
# Karlo is extremely likely to strongly deviate depending on which hardware is used
|
||||||
|
# Here we just check that the image doesn't deviate more than 10 pixels from the reference image on average
|
||||||
|
avg_diff = np.abs(image - expected_image).mean()
|
||||||
|
|
||||||
|
assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
|
||||||
assert image.shape == (256, 256, 3)
|
assert image.shape == (256, 256, 3)
|
||||||
assert np.abs(expected_image - image).max() < 5e-2
|
|
||||||
|
|||||||
Reference in New Issue
Block a user