Fix type mismatch error, add tests for negative prompts (#823)
This commit is contained in:
@@ -234,8 +234,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||||||
uncond_tokens = [""]
|
uncond_tokens = [""]
|
||||||
elif type(prompt) is not type(negative_prompt):
|
elif type(prompt) is not type(negative_prompt):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||||
" {type(prompt)}."
|
f" {type(prompt)}."
|
||||||
)
|
)
|
||||||
elif isinstance(negative_prompt, str):
|
elif isinstance(negative_prompt, str):
|
||||||
uncond_tokens = [negative_prompt]
|
uncond_tokens = [negative_prompt]
|
||||||
|
|||||||
@@ -195,7 +195,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||||||
"""
|
"""
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
prompt = [prompt]
|
|
||||||
elif isinstance(prompt, list):
|
elif isinstance(prompt, list):
|
||||||
batch_size = len(prompt)
|
batch_size = len(prompt)
|
||||||
else:
|
else:
|
||||||
@@ -250,8 +249,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||||||
uncond_tokens = [""]
|
uncond_tokens = [""]
|
||||||
elif type(prompt) is not type(negative_prompt):
|
elif type(prompt) is not type(negative_prompt):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||||
" {type(prompt)}."
|
f" {type(prompt)}."
|
||||||
)
|
)
|
||||||
elif isinstance(negative_prompt, str):
|
elif isinstance(negative_prompt, str):
|
||||||
uncond_tokens = [negative_prompt]
|
uncond_tokens = [negative_prompt]
|
||||||
@@ -285,6 +284,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||||||
init_latents = init_latent_dist.sample(generator=generator)
|
init_latents = init_latent_dist.sample(generator=generator)
|
||||||
init_latents = 0.18215 * init_latents
|
init_latents = 0.18215 * init_latents
|
||||||
|
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt = [prompt]
|
||||||
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
|
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
|
||||||
# expand init_latents for batch_size
|
# expand init_latents for batch_size
|
||||||
deprecation_message = (
|
deprecation_message = (
|
||||||
|
|||||||
@@ -266,8 +266,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||||||
uncond_tokens = [""]
|
uncond_tokens = [""]
|
||||||
elif type(prompt) is not type(negative_prompt):
|
elif type(prompt) is not type(negative_prompt):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||||
" {type(prompt)}."
|
f" {type(prompt)}."
|
||||||
)
|
)
|
||||||
elif isinstance(negative_prompt, str):
|
elif isinstance(negative_prompt, str):
|
||||||
uncond_tokens = [negative_prompt]
|
uncond_tokens = [negative_prompt]
|
||||||
|
|||||||
@@ -108,8 +108,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
|
|||||||
uncond_tokens = [""] * batch_size
|
uncond_tokens = [""] * batch_size
|
||||||
elif type(prompt) is not type(negative_prompt):
|
elif type(prompt) is not type(negative_prompt):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||||
" {type(prompt)}."
|
f" {type(prompt)}."
|
||||||
)
|
)
|
||||||
elif isinstance(negative_prompt, str):
|
elif isinstance(negative_prompt, str):
|
||||||
uncond_tokens = [negative_prompt] * batch_size
|
uncond_tokens = [negative_prompt] * batch_size
|
||||||
|
|||||||
@@ -575,6 +575,46 @@ class PipelineFastTests(unittest.TestCase):
|
|||||||
|
|
||||||
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
|
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
|
||||||
|
|
||||||
|
def test_stable_diffusion_negative_prompt(self):
|
||||||
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
|
unet = self.dummy_cond_unet
|
||||||
|
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||||
|
vae = self.dummy_vae
|
||||||
|
bert = self.dummy_text_encoder
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||||
|
|
||||||
|
# make sure here that pndm scheduler skips prk
|
||||||
|
sd_pipe = StableDiffusionPipeline(
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=bert,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
safety_checker=self.dummy_safety_checker,
|
||||||
|
feature_extractor=self.dummy_extractor,
|
||||||
|
)
|
||||||
|
sd_pipe = sd_pipe.to(device)
|
||||||
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
|
negative_prompt = "french fries"
|
||||||
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
|
output = sd_pipe(
|
||||||
|
prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
generator=generator,
|
||||||
|
guidance_scale=6.0,
|
||||||
|
num_inference_steps=2,
|
||||||
|
output_type="np",
|
||||||
|
)
|
||||||
|
|
||||||
|
image = output.images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
assert image.shape == (1, 128, 128, 3)
|
||||||
|
expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719])
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_score_sde_ve_pipeline(self):
|
def test_score_sde_ve_pipeline(self):
|
||||||
unet = self.dummy_uncond_unet
|
unet = self.dummy_uncond_unet
|
||||||
scheduler = ScoreSdeVeScheduler()
|
scheduler = ScoreSdeVeScheduler()
|
||||||
@@ -704,6 +744,48 @@ class PipelineFastTests(unittest.TestCase):
|
|||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
def test_stable_diffusion_img2img_negative_prompt(self):
|
||||||
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
|
unet = self.dummy_cond_unet
|
||||||
|
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||||
|
vae = self.dummy_vae
|
||||||
|
bert = self.dummy_text_encoder
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||||
|
|
||||||
|
init_image = self.dummy_image.to(device)
|
||||||
|
|
||||||
|
# make sure here that pndm scheduler skips prk
|
||||||
|
sd_pipe = StableDiffusionImg2ImgPipeline(
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=bert,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
safety_checker=self.dummy_safety_checker,
|
||||||
|
feature_extractor=self.dummy_extractor,
|
||||||
|
)
|
||||||
|
sd_pipe = sd_pipe.to(device)
|
||||||
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
|
negative_prompt = "french fries"
|
||||||
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
|
output = sd_pipe(
|
||||||
|
prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
generator=generator,
|
||||||
|
guidance_scale=6.0,
|
||||||
|
num_inference_steps=2,
|
||||||
|
output_type="np",
|
||||||
|
init_image=init_image,
|
||||||
|
)
|
||||||
|
image = output.images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
assert image.shape == (1, 32, 32, 3)
|
||||||
|
expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365])
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_img2img_multiple_init_images(self):
|
def test_stable_diffusion_img2img_multiple_init_images(self):
|
||||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
unet = self.dummy_cond_unet
|
unet = self.dummy_cond_unet
|
||||||
@@ -861,6 +943,52 @@ class PipelineFastTests(unittest.TestCase):
|
|||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
def test_stable_diffusion_inpaint_negative_prompt(self):
|
||||||
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
|
unet = self.dummy_cond_unet
|
||||||
|
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||||
|
vae = self.dummy_vae
|
||||||
|
bert = self.dummy_text_encoder
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||||
|
|
||||||
|
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||||
|
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||||
|
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||||
|
|
||||||
|
# make sure here that pndm scheduler skips prk
|
||||||
|
sd_pipe = StableDiffusionInpaintPipeline(
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=bert,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
safety_checker=self.dummy_safety_checker,
|
||||||
|
feature_extractor=self.dummy_extractor,
|
||||||
|
)
|
||||||
|
sd_pipe = sd_pipe.to(device)
|
||||||
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
|
negative_prompt = "french fries"
|
||||||
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
|
output = sd_pipe(
|
||||||
|
prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
generator=generator,
|
||||||
|
guidance_scale=6.0,
|
||||||
|
num_inference_steps=2,
|
||||||
|
output_type="np",
|
||||||
|
init_image=init_image,
|
||||||
|
mask_image=mask_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = output.images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
assert image.shape == (1, 32, 32, 3)
|
||||||
|
expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150])
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_num_images_per_prompt(self):
|
def test_stable_diffusion_num_images_per_prompt(self):
|
||||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
unet = self.dummy_cond_unet
|
unet = self.dummy_cond_unet
|
||||||
|
|||||||
Reference in New Issue
Block a user