Compare commits

..

4 Commits

Author SHA1 Message Date
Dhruv Nair 840344b817 update 2024-03-12 10:36:41 +00:00
Dhruv Nair 7739271db3 update 2024-03-12 10:31:29 +00:00
Dhruv Nair 7f1ea22c07 update 2024-03-11 14:56:12 +00:00
Dhruv Nair 0de7e023fd update 2024-03-11 13:18:13 +00:00
9 changed files with 25 additions and 175 deletions
+6 -1
View File
@@ -56,6 +56,8 @@ def build_sub_model_components(
if component_name == "unet":
num_in_channels = kwargs.pop("num_in_channels", None)
upcast_attention = kwargs.pop("upcast_attention", None)
unet_components = create_diffusers_unet_model_from_ldm(
pipeline_class_name,
original_config,
@@ -64,6 +66,7 @@ def build_sub_model_components(
image_size=image_size,
torch_dtype=torch_dtype,
model_type=model_type,
upcast_attention=upcast_attention,
)
return unet_components
@@ -300,7 +303,9 @@ class FromSingleFileMixin:
continue
init_kwargs.update(components)
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
additional_components = set_additional_components(
class_name, original_config, checkpoint=checkpoint, model_type=model_type
)
if additional_components:
init_kwargs.update(additional_components)
+4 -3
View File
@@ -307,7 +307,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
return original_config
def infer_model_type(original_config, checkpoint=None, model_type=None):
def infer_model_type(original_config, checkpoint, model_type=None):
if model_type is not None:
return model_type
@@ -1176,7 +1176,7 @@ def create_diffusers_unet_model_from_ldm(
original_config,
checkpoint,
num_in_channels=None,
upcast_attention=False,
upcast_attention=None,
extract_ema=False,
image_size=None,
torch_dtype=None,
@@ -1204,7 +1204,8 @@ def create_diffusers_unet_model_from_ldm(
)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["in_channels"] = num_in_channels
unet_config["upcast_attention"] = upcast_attention
if upcast_attention is not None:
unet_config["upcast_attention"] = upcast_attention
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
@@ -289,9 +289,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
@@ -323,17 +321,10 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -387,7 +378,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
# 2. Encode caption
if prompt_embeds is None and negative_prompt_embeds is None:
_, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt(
prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt(
prompt=prompt,
device=device,
batch_size=batch_size,
@@ -395,16 +386,10 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
)
# The pooled embeds from the prior are pooled again before being passed to the decoder
prompt_embeds_pooled = (
torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
if self.do_classifier_free_guidance
else prompt_embeds_pooled
torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds
)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
@@ -155,14 +155,14 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
height: int = 512,
width: int = 512,
prior_num_inference_steps: int = 60,
prior_timesteps: Optional[List[float]] = None,
prior_guidance_scale: float = 4.0,
num_inference_steps: int = 12,
decoder_timesteps: Optional[List[float]] = None,
decoder_guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
@@ -187,17 +187,10 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
input argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
input argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to 512):
@@ -260,6 +253,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
prior_outputs = self.prior_pipe(
prompt=prompt if prompt_embeds is None else None,
images=images,
@@ -269,9 +263,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
guidance_scale=prior_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
@@ -282,9 +274,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
)
image_embeddings = prior_outputs.image_embeddings
prompt_embeds = prior_outputs.get("prompt_embeds", None)
prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None)
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None)
outputs = self.decoder_pipe(
image_embeddings=image_embeddings,
@@ -293,9 +283,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
guidance_scale=decoder_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
generator=generator,
output_type=output_type,
return_dict=return_dict,
@@ -64,9 +64,7 @@ class StableCascadePriorPipelineOutput(BaseOutput):
image_embeddings: Union[torch.FloatTensor, np.ndarray]
prompt_embeds: Union[torch.FloatTensor, np.ndarray]
prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray]
negative_prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
class StableCascadePriorPipeline(DiffusionPipeline):
@@ -307,16 +305,6 @@ class StableCascadePriorPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and prompt_embeds_pooled is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
)
if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
)
if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
raise ValueError(
@@ -351,7 +339,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
def num_timesteps(self):
return self._num_timesteps
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
def get_t_condioning(self, t, alphas_cumprod):
s = torch.tensor([0.003])
clamp_range = [0, 1]
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
@@ -570,7 +558,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
for i, t in enumerate(self.progress_bar(timesteps)):
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
if len(alphas_cumprod) > 0:
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
else:
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
@@ -621,18 +609,6 @@ class StableCascadePriorPipeline(DiffusionPipeline):
) # float() as bfloat16-> numpy doesnt work
if not return_dict:
return (
latents,
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
)
return (latents, prompt_embeds, negative_prompt_embeds)
return StableCascadePriorPipelineOutput(
image_embeddings=latents,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
)
return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds)
@@ -241,39 +241,6 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC
def test_callback_inputs(self):
super().test_callback_inputs()
def test_stable_cascade_combined_prompt_embeds(self):
device = "cpu"
components = self.get_dummy_components()
pipe = StableCascadeCombinedPipeline(**components)
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of a shiba inu, wearing a hat"
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = pipe.prior_pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
generator = torch.Generator(device=device)
output_prompt = pipe(
prompt=prompt,
num_inference_steps=1,
prior_num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
output_prompt_embeds = pipe(
prompt=None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_inference_steps=1,
prior_num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5
# def test_callback_cfg(self):
# pass
# pass
@@ -207,45 +207,6 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa
def test_float16_inference(self):
super().test_float16_inference()
def test_stable_cascade_decoder_prompt_embeds(self):
device = "cpu"
components = self.get_dummy_components()
pipe = StableCascadeDecoderPipeline(**components)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image_embeddings = inputs["image_embeddings"]
prompt = "A photograph of a shiba inu, wearing a hat"
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
generator = torch.Generator(device=device)
decoder_output_prompt = pipe(
image_embeddings=image_embeddings,
prompt=prompt,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
decoder_output_prompt_embeds = pipe(
image_embeddings=image_embeddings,
prompt=None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
@slow
@require_torch_gpu
@@ -273,41 +273,6 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase
self.assertTrue(image_embed.shape == lora_image_embed.shape)
def test_stable_cascade_decoder_prompt_embeds(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of a shiba inu, wearing a hat"
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
generator = torch.Generator(device=device)
output_prompt = pipe(
prompt=prompt,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
output_prompt_embeds = pipe(
prompt=None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
assert np.abs(output_prompt.image_embeddings - output_prompt_embeds.image_embeddings).max() < 1e-5
@slow
@require_torch_gpu
@@ -838,9 +838,11 @@ class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
pipe.unet.config[param_name] = False
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
), f"{param_name} is differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE: