pix2pix tests no write to fs (#2497)

* attend and excite batch test causing timeouts

* pix2pix tests, no write to fs
This commit is contained in:
Will Berman
2023-02-27 15:26:28 -08:00
committed by GitHub
parent 42beaf1d23
commit 1586186eea
2 changed files with 44 additions and 39 deletions
+7
View File
@@ -209,6 +209,13 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
return arry return arry
def load_pt(url: str):
response = requests.get(url)
response.raise_for_status()
arry = torch.load(BytesIO(response.content))
return arry
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
""" """
Args: Args:
@@ -17,9 +17,7 @@ import gc
import unittest import unittest
import numpy as np import numpy as np
import requests
import torch import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import ( from diffusers import (
@@ -33,7 +31,7 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils import load_numpy, slow, torch_device from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
from ...test_pipelines_common import PipelineTesterMixin from ...test_pipelines_common import PipelineTesterMixin
@@ -41,16 +39,20 @@ from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
def download_from_url(embedding_url, local_filepath):
r = requests.get(embedding_url)
with open(local_filepath, "wb") as f:
f.write(r.content)
@skip_mps @skip_mps
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionPix2PixZeroPipeline pipeline_class = StableDiffusionPix2PixZeroPipeline
@classmethod
def setUpClass(cls):
cls.source_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/src_emb_0.pt"
)
cls.target_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/tgt_emb_0.pt"
)
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
unet = UNet2DConditionModel( unet = UNet2DConditionModel(
@@ -103,15 +105,6 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
return components return components
def get_dummy_inputs(self, device, seed=0): def get_dummy_inputs(self, device, seed=0):
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt"
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt"
for url in [src_emb_url, tgt_emb_url]:
download_from_url(url, url.split("/")[-1])
src_embeds = torch.load(src_emb_url.split("/")[-1])
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
inputs = { inputs = {
@@ -120,8 +113,8 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
"num_inference_steps": 2, "num_inference_steps": 2,
"guidance_scale": 6.0, "guidance_scale": 6.0,
"cross_attention_guidance_amount": 0.15, "cross_attention_guidance_amount": 0.15,
"source_embeds": src_embeds, "source_embeds": self.source_embeds,
"target_embeds": target_embeds, "target_embeds": self.target_embeds,
"output_type": "numpy", "output_type": "numpy",
} }
return inputs return inputs
@@ -237,26 +230,27 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@classmethod
def setUpClass(cls):
cls.source_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat.pt"
)
cls.target_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.pt"
)
def get_inputs(self, seed=0): def get_inputs(self, seed=0):
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"
for url in [src_emb_url, tgt_emb_url]:
download_from_url(url, url.split("/")[-1])
src_embeds = torch.load(src_emb_url.split("/")[-1])
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
inputs = { inputs = {
"prompt": "turn him into a cyborg", "prompt": "turn him into a cyborg",
"generator": generator, "generator": generator,
"num_inference_steps": 3, "num_inference_steps": 3,
"guidance_scale": 7.5, "guidance_scale": 7.5,
"cross_attention_guidance_amount": 0.15, "cross_attention_guidance_amount": 0.15,
"source_embeds": src_embeds, "source_embeds": self.source_embeds,
"target_embeds": target_embeds, "target_embeds": self.target_embeds,
"output_type": "numpy", "output_type": "numpy",
} }
return inputs return inputs
@@ -364,10 +358,17 @@ class InversionPipelineSlowTests(unittest.TestCase):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def test_stable_diffusion_pix2pix_inversion(self): @classmethod
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" def setUpClass(cls):
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) raw_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
)
raw_image = raw_image.convert("RGB").resize((512, 512))
cls.raw_image = raw_image
def test_stable_diffusion_pix2pix_inversion(self):
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
) )
@@ -380,7 +381,7 @@ class InversionPipelineSlowTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
output = pipe.invert(caption, image=raw_image, generator=generator, num_inference_steps=10) output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
inv_latents = output[0] inv_latents = output[0]
image_slice = inv_latents[0, -3:, -3:, -1].flatten() image_slice = inv_latents[0, -3:, -3:, -1].flatten()
@@ -391,9 +392,6 @@ class InversionPipelineSlowTests(unittest.TestCase):
assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3 assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3
def test_stable_diffusion_pix2pix_full(self): def test_stable_diffusion_pix2pix_full(self):
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
expected_image = load_numpy( expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy" "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy"
@@ -411,7 +409,7 @@ class InversionPipelineSlowTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
output = pipe.invert(caption, image=raw_image, generator=generator) output = pipe.invert(caption, image=self.raw_image, generator=generator)
inv_latents = output[0] inv_latents = output[0]
source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]