[ip-adapter] fix ip-adapter for StableDiffusionInstructPix2PixPipeline (#7820)
update prepare_ip_adapter_ for pix2pix
This commit is contained in:
+87
-8
@@ -172,6 +172,7 @@ class StableDiffusionInstructPix2PixPipeline(
|
|||||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||||
|
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||||
output_type: Optional[str] = "pil",
|
output_type: Optional[str] = "pil",
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||||
@@ -296,6 +297,8 @@ class StableDiffusionInstructPix2PixPipeline(
|
|||||||
negative_prompt,
|
negative_prompt,
|
||||||
prompt_embeds,
|
prompt_embeds,
|
||||||
negative_prompt_embeds,
|
negative_prompt_embeds,
|
||||||
|
ip_adapter_image,
|
||||||
|
ip_adapter_image_embeds,
|
||||||
callback_on_step_end_tensor_inputs,
|
callback_on_step_end_tensor_inputs,
|
||||||
)
|
)
|
||||||
self._guidance_scale = guidance_scale
|
self._guidance_scale = guidance_scale
|
||||||
@@ -303,14 +306,6 @@ class StableDiffusionInstructPix2PixPipeline(
|
|||||||
|
|
||||||
device = self._execution_device
|
device = self._execution_device
|
||||||
|
|
||||||
if ip_adapter_image is not None:
|
|
||||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
|
||||||
image_embeds, negative_image_embeds = self.encode_image(
|
|
||||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
|
||||||
)
|
|
||||||
if self.do_classifier_free_guidance:
|
|
||||||
image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])
|
|
||||||
|
|
||||||
if image is None:
|
if image is None:
|
||||||
raise ValueError("`image` input cannot be undefined.")
|
raise ValueError("`image` input cannot be undefined.")
|
||||||
|
|
||||||
@@ -335,6 +330,14 @@ class StableDiffusionInstructPix2PixPipeline(
|
|||||||
negative_prompt_embeds=negative_prompt_embeds,
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||||
|
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||||
|
ip_adapter_image,
|
||||||
|
ip_adapter_image_embeds,
|
||||||
|
device,
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
self.do_classifier_free_guidance,
|
||||||
|
)
|
||||||
# 3. Preprocess image
|
# 3. Preprocess image
|
||||||
image = self.image_processor.preprocess(image)
|
image = self.image_processor.preprocess(image)
|
||||||
|
|
||||||
@@ -635,6 +638,65 @@ class StableDiffusionInstructPix2PixPipeline(
|
|||||||
|
|
||||||
return image_embeds, uncond_image_embeds
|
return image_embeds, uncond_image_embeds
|
||||||
|
|
||||||
|
def prepare_ip_adapter_image_embeds(
|
||||||
|
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
||||||
|
):
|
||||||
|
if ip_adapter_image_embeds is None:
|
||||||
|
if not isinstance(ip_adapter_image, list):
|
||||||
|
ip_adapter_image = [ip_adapter_image]
|
||||||
|
|
||||||
|
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||||
|
raise ValueError(
|
||||||
|
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||||
|
)
|
||||||
|
|
||||||
|
image_embeds = []
|
||||||
|
for single_ip_adapter_image, image_proj_layer in zip(
|
||||||
|
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||||
|
):
|
||||||
|
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||||
|
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||||
|
single_ip_adapter_image, device, 1, output_hidden_state
|
||||||
|
)
|
||||||
|
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||||
|
single_negative_image_embeds = torch.stack(
|
||||||
|
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
single_image_embeds = torch.cat(
|
||||||
|
[single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
|
||||||
|
)
|
||||||
|
single_image_embeds = single_image_embeds.to(device)
|
||||||
|
|
||||||
|
image_embeds.append(single_image_embeds)
|
||||||
|
else:
|
||||||
|
repeat_dims = [1]
|
||||||
|
image_embeds = []
|
||||||
|
for single_image_embeds in ip_adapter_image_embeds:
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
(
|
||||||
|
single_image_embeds,
|
||||||
|
single_negative_image_embeds,
|
||||||
|
single_negative_image_embeds,
|
||||||
|
) = single_image_embeds.chunk(3)
|
||||||
|
single_image_embeds = single_image_embeds.repeat(
|
||||||
|
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||||
|
)
|
||||||
|
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||||
|
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||||
|
)
|
||||||
|
single_image_embeds = torch.cat(
|
||||||
|
[single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
single_image_embeds = single_image_embeds.repeat(
|
||||||
|
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||||
|
)
|
||||||
|
image_embeds.append(single_image_embeds)
|
||||||
|
|
||||||
|
return image_embeds
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||||
def run_safety_checker(self, image, device, dtype):
|
def run_safety_checker(self, image, device, dtype):
|
||||||
if self.safety_checker is None:
|
if self.safety_checker is None:
|
||||||
@@ -687,6 +749,8 @@ class StableDiffusionInstructPix2PixPipeline(
|
|||||||
negative_prompt=None,
|
negative_prompt=None,
|
||||||
prompt_embeds=None,
|
prompt_embeds=None,
|
||||||
negative_prompt_embeds=None,
|
negative_prompt_embeds=None,
|
||||||
|
ip_adapter_image=None,
|
||||||
|
ip_adapter_image_embeds=None,
|
||||||
callback_on_step_end_tensor_inputs=None,
|
callback_on_step_end_tensor_inputs=None,
|
||||||
):
|
):
|
||||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||||
@@ -728,6 +792,21 @@ class StableDiffusionInstructPix2PixPipeline(
|
|||||||
f" {negative_prompt_embeds.shape}."
|
f" {negative_prompt_embeds.shape}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
||||||
|
)
|
||||||
|
|
||||||
|
if ip_adapter_image_embeds is not None:
|
||||||
|
if not isinstance(ip_adapter_image_embeds, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||||
|
)
|
||||||
|
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||||
|
raise ValueError(
|
||||||
|
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||||
|
)
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||||
shape = (
|
shape = (
|
||||||
|
|||||||
-1
@@ -436,7 +436,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
|||||||
extra_step_kwargs["generator"] = generator
|
extra_step_kwargs["generator"] = generator
|
||||||
return extra_step_kwargs
|
return extra_step_kwargs
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
|
|
||||||
def check_inputs(
|
def check_inputs(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
|
|||||||
Reference in New Issue
Block a user