pix2pix tests no write to fs (#2497)
* attend and excite batch test causing timeouts * pix2pix tests, no write to fs
This commit is contained in:
@@ -209,6 +209,13 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
|
|||||||
return arry
|
return arry
|
||||||
|
|
||||||
|
|
||||||
|
def load_pt(url: str):
|
||||||
|
response = requests.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
arry = torch.load(BytesIO(response.content))
|
||||||
|
return arry
|
||||||
|
|
||||||
|
|
||||||
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -17,9 +17,7 @@ import gc
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
@@ -33,7 +31,7 @@ from diffusers import (
|
|||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
)
|
)
|
||||||
from diffusers.utils import load_numpy, slow, torch_device
|
from diffusers.utils import load_numpy, slow, torch_device
|
||||||
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
|
from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
|
||||||
|
|
||||||
from ...test_pipelines_common import PipelineTesterMixin
|
from ...test_pipelines_common import PipelineTesterMixin
|
||||||
|
|
||||||
@@ -41,16 +39,20 @@ from ...test_pipelines_common import PipelineTesterMixin
|
|||||||
torch.backends.cuda.matmul.allow_tf32 = False
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
|
||||||
|
|
||||||
def download_from_url(embedding_url, local_filepath):
|
|
||||||
r = requests.get(embedding_url)
|
|
||||||
with open(local_filepath, "wb") as f:
|
|
||||||
f.write(r.content)
|
|
||||||
|
|
||||||
|
|
||||||
@skip_mps
|
@skip_mps
|
||||||
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
pipeline_class = StableDiffusionPix2PixZeroPipeline
|
pipeline_class = StableDiffusionPix2PixZeroPipeline
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.source_embeds = load_pt(
|
||||||
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/src_emb_0.pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
cls.target_embeds = load_pt(
|
||||||
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/tgt_emb_0.pt"
|
||||||
|
)
|
||||||
|
|
||||||
def get_dummy_components(self):
|
def get_dummy_components(self):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
unet = UNet2DConditionModel(
|
unet = UNet2DConditionModel(
|
||||||
@@ -103,15 +105,6 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
|
|||||||
return components
|
return components
|
||||||
|
|
||||||
def get_dummy_inputs(self, device, seed=0):
|
def get_dummy_inputs(self, device, seed=0):
|
||||||
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt"
|
|
||||||
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt"
|
|
||||||
|
|
||||||
for url in [src_emb_url, tgt_emb_url]:
|
|
||||||
download_from_url(url, url.split("/")[-1])
|
|
||||||
|
|
||||||
src_embeds = torch.load(src_emb_url.split("/")[-1])
|
|
||||||
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
|
|
||||||
|
|
||||||
generator = torch.manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
@@ -120,8 +113,8 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
|
|||||||
"num_inference_steps": 2,
|
"num_inference_steps": 2,
|
||||||
"guidance_scale": 6.0,
|
"guidance_scale": 6.0,
|
||||||
"cross_attention_guidance_amount": 0.15,
|
"cross_attention_guidance_amount": 0.15,
|
||||||
"source_embeds": src_embeds,
|
"source_embeds": self.source_embeds,
|
||||||
"target_embeds": target_embeds,
|
"target_embeds": self.target_embeds,
|
||||||
"output_type": "numpy",
|
"output_type": "numpy",
|
||||||
}
|
}
|
||||||
return inputs
|
return inputs
|
||||||
@@ -237,26 +230,27 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.source_embeds = load_pt(
|
||||||
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat.pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
cls.target_embeds = load_pt(
|
||||||
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.pt"
|
||||||
|
)
|
||||||
|
|
||||||
def get_inputs(self, seed=0):
|
def get_inputs(self, seed=0):
|
||||||
generator = torch.manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
|
|
||||||
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
|
|
||||||
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"
|
|
||||||
|
|
||||||
for url in [src_emb_url, tgt_emb_url]:
|
|
||||||
download_from_url(url, url.split("/")[-1])
|
|
||||||
|
|
||||||
src_embeds = torch.load(src_emb_url.split("/")[-1])
|
|
||||||
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
|
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
"prompt": "turn him into a cyborg",
|
"prompt": "turn him into a cyborg",
|
||||||
"generator": generator,
|
"generator": generator,
|
||||||
"num_inference_steps": 3,
|
"num_inference_steps": 3,
|
||||||
"guidance_scale": 7.5,
|
"guidance_scale": 7.5,
|
||||||
"cross_attention_guidance_amount": 0.15,
|
"cross_attention_guidance_amount": 0.15,
|
||||||
"source_embeds": src_embeds,
|
"source_embeds": self.source_embeds,
|
||||||
"target_embeds": target_embeds,
|
"target_embeds": self.target_embeds,
|
||||||
"output_type": "numpy",
|
"output_type": "numpy",
|
||||||
}
|
}
|
||||||
return inputs
|
return inputs
|
||||||
@@ -364,10 +358,17 @@ class InversionPipelineSlowTests(unittest.TestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def test_stable_diffusion_pix2pix_inversion(self):
|
@classmethod
|
||||||
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
|
def setUpClass(cls):
|
||||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
|
raw_image = load_image(
|
||||||
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_image = raw_image.convert("RGB").resize((512, 512))
|
||||||
|
|
||||||
|
cls.raw_image = raw_image
|
||||||
|
|
||||||
|
def test_stable_diffusion_pix2pix_inversion(self):
|
||||||
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||||
)
|
)
|
||||||
@@ -380,7 +381,7 @@ class InversionPipelineSlowTests(unittest.TestCase):
|
|||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe.invert(caption, image=raw_image, generator=generator, num_inference_steps=10)
|
output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
|
||||||
inv_latents = output[0]
|
inv_latents = output[0]
|
||||||
|
|
||||||
image_slice = inv_latents[0, -3:, -3:, -1].flatten()
|
image_slice = inv_latents[0, -3:, -3:, -1].flatten()
|
||||||
@@ -391,9 +392,6 @@ class InversionPipelineSlowTests(unittest.TestCase):
|
|||||||
assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3
|
assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_pix2pix_full(self):
|
def test_stable_diffusion_pix2pix_full(self):
|
||||||
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
|
|
||||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
|
|
||||||
|
|
||||||
# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
|
# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
|
||||||
expected_image = load_numpy(
|
expected_image = load_numpy(
|
||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy"
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy"
|
||||||
@@ -411,7 +409,7 @@ class InversionPipelineSlowTests(unittest.TestCase):
|
|||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe.invert(caption, image=raw_image, generator=generator)
|
output = pipe.invert(caption, image=self.raw_image, generator=generator)
|
||||||
inv_latents = output[0]
|
inv_latents = output[0]
|
||||||
|
|
||||||
source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
|
source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
|
||||||
|
|||||||
Reference in New Issue
Block a user