Compare commits

..

1 Commits

Author SHA1 Message Date
Patrick von Platen 6b2d9e6acd [Draft] MultiControlNet 2023-03-09 10:46:24 +00:00
@@ -14,17 +14,14 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from torch import device, nn
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.controlnet import ControlNetOutput
from ...models.modeling_utils import get_parameter_device, get_parameter_dtype
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
@@ -88,60 +85,6 @@ EXAMPLE_DOC_STRING = """
"""
class MultiControlNet(nn.Module):
def __init__(self, controlnets: List[ControlNetModel]):
super().__init__()
self.nets = nn.ModuleList(controlnets)
@property
def device(self) -> device:
"""
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
return get_parameter_device(self)
@property
def dtype(self) -> torch.dtype:
"""
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
return get_parameter_dtype(self)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
num_images_per_net = controlnet_cond.shape[0] // len(self.nets)
conds = controlnet_cond[None, :].reshape((num_images_per_net, -1) + controlnet_cond.shape[1:])
down_block_res_samples, mid_block_res_sample = 0
for cond, controlnet in zip(conds, self.nets):
down, mid = self.controlnet(
sample,
timestep,
encoder_hidden_states,
cond,
class_labels,
timestep_cond,
attention_mask,
cross_attention_kwargs,
return_dict,
)
down_block_res_samples += down
mid_block_res_sample += mid
return down_block_res_samples, mid_block_res_sample
class StableDiffusionControlNetPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -203,8 +146,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNet(controlnet)
if isinstance(controlnet, list):
controlnet = torch.nn.ModuleList(controlnet)
self.register_modules(
vae=vae,
@@ -577,11 +520,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
num_controlnets = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
image_batch_size = image.shape[0] // num_controlnets
if image_batch_size != image.shape[0] * num_controlnets:
raise ValueError("TODO: Good error message here")
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
@@ -780,12 +719,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
)
if do_classifier_free_guidance:
num_control = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
image = image[None, :].reshape(num_control, -1, *image.shape[1:])
# only repeat batch size, but not controlnet dim
image = image.repeat(1, 2, 1, 1, 1)
image = image.reshape((image.shape[:2].numel(),) + image.shape[2:])
image = torch.cat([image] * 2)
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -807,7 +741,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
# 8. Prepare controlnets for potential multi-controlnet case
controlnets = self.controlnet if isinstance(self.controlnet, torch.nn.ModuleList) else [self.controlnet]
images_per_controlnet = image.shape[0] // len(controlnets)
if images_per_controlnet * len(controlnets) != image.shape[0]:
raise ValueError(f"You have passed {len(controlnets)} ControlNet models, but {image.shape[0]} conditioned images. Please make sure to pass `n` x {len(controlnets)} images to generate `n` output images.")
control_images = image[None, :].reshape(len(controlnets), images_per_controlnet, *image.shape[1:])
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -815,19 +758,23 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
return_dict=False,
)
control_down_block_res = control_mid_block_res = 0
for image, controlnet in zip(control_images, controlnets):
down_res, mid_res = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
return_dict=False,
)
control_down_block_res += down_res
control_mid_block_res += mid_res
down_block_res_samples = [
control_down_block_res = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
for down_block_res_sample in control_down_block_res
]
mid_block_res_sample *= controlnet_conditioning_scale
control_mid_block_res *= controlnet_conditioning_scale
# predict the noise residual
noise_pred = self.unet(
@@ -835,8 +782,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
down_block_additional_residuals=control_down_block_res,
mid_block_additional_residual=control_mid_block_res,
).sample
# perform guidance