[ip-adapter] fix ip-adapter for StableDiffusionInstructPix2PixPipeline (#7820)
update prepare_ip_adapter_ for pix2pix
This commit is contained in:
+87
-8
@@ -172,6 +172,7 @@ class StableDiffusionInstructPix2PixPipeline(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
@@ -296,6 +297,8 @@ class StableDiffusionInstructPix2PixPipeline(
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
@@ -303,14 +306,6 @@ class StableDiffusionInstructPix2PixPipeline(
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])
|
||||
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
@@ -335,6 +330,14 @@ class StableDiffusionInstructPix2PixPipeline(
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
# 3. Preprocess image
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
@@ -635,6 +638,65 @@ class StableDiffusionInstructPix2PixPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
||||
):
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack(
|
||||
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat(
|
||||
[single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
|
||||
)
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
(
|
||||
single_image_embeds,
|
||||
single_negative_image_embeds,
|
||||
single_negative_image_embeds,
|
||||
) = single_image_embeds.chunk(3)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat(
|
||||
[single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
|
||||
)
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -687,6 +749,8 @@ class StableDiffusionInstructPix2PixPipeline(
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
ip_adapter_image=None,
|
||||
ip_adapter_image_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
@@ -728,6 +792,21 @@ class StableDiffusionInstructPix2PixPipeline(
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
||||
)
|
||||
|
||||
if ip_adapter_image_embeds is not None:
|
||||
if not isinstance(ip_adapter_image_embeds, list):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (
|
||||
|
||||
-1
@@ -436,7 +436,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
|
||||
Reference in New Issue
Block a user