pix2pix tests no write to fs (#2497)
* attend and excite batch test causing timeouts * pix2pix tests, no write to fs
This commit is contained in:
@@ -209,6 +209,13 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
|
||||
return arry
|
||||
|
||||
|
||||
def load_pt(url: str):
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
arry = torch.load(BytesIO(response.content))
|
||||
return arry
|
||||
|
||||
|
||||
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -17,9 +17,7 @@ import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
@@ -33,7 +31,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import load_numpy, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
|
||||
from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
@@ -41,16 +39,20 @@ from ...test_pipelines_common import PipelineTesterMixin
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
def download_from_url(embedding_url, local_filepath):
|
||||
r = requests.get(embedding_url)
|
||||
with open(local_filepath, "wb") as f:
|
||||
f.write(r.content)
|
||||
|
||||
|
||||
@skip_mps
|
||||
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionPix2PixZeroPipeline
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.source_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/src_emb_0.pt"
|
||||
)
|
||||
|
||||
cls.target_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/tgt_emb_0.pt"
|
||||
)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
@@ -103,15 +105,6 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt"
|
||||
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt"
|
||||
|
||||
for url in [src_emb_url, tgt_emb_url]:
|
||||
download_from_url(url, url.split("/")[-1])
|
||||
|
||||
src_embeds = torch.load(src_emb_url.split("/")[-1])
|
||||
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
|
||||
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
@@ -120,8 +113,8 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"cross_attention_guidance_amount": 0.15,
|
||||
"source_embeds": src_embeds,
|
||||
"target_embeds": target_embeds,
|
||||
"source_embeds": self.source_embeds,
|
||||
"target_embeds": self.target_embeds,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
@@ -237,26 +230,27 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.source_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat.pt"
|
||||
)
|
||||
|
||||
cls.target_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.pt"
|
||||
)
|
||||
|
||||
def get_inputs(self, seed=0):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
|
||||
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"
|
||||
|
||||
for url in [src_emb_url, tgt_emb_url]:
|
||||
download_from_url(url, url.split("/")[-1])
|
||||
|
||||
src_embeds = torch.load(src_emb_url.split("/")[-1])
|
||||
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
|
||||
|
||||
inputs = {
|
||||
"prompt": "turn him into a cyborg",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 7.5,
|
||||
"cross_attention_guidance_amount": 0.15,
|
||||
"source_embeds": src_embeds,
|
||||
"target_embeds": target_embeds,
|
||||
"source_embeds": self.source_embeds,
|
||||
"target_embeds": self.target_embeds,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
@@ -364,10 +358,17 @@ class InversionPipelineSlowTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_diffusion_pix2pix_inversion(self):
|
||||
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
raw_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
|
||||
)
|
||||
|
||||
raw_image = raw_image.convert("RGB").resize((512, 512))
|
||||
|
||||
cls.raw_image = raw_image
|
||||
|
||||
def test_stable_diffusion_pix2pix_inversion(self):
|
||||
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
@@ -380,7 +381,7 @@ class InversionPipelineSlowTests(unittest.TestCase):
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe.invert(caption, image=raw_image, generator=generator, num_inference_steps=10)
|
||||
output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
|
||||
inv_latents = output[0]
|
||||
|
||||
image_slice = inv_latents[0, -3:, -3:, -1].flatten()
|
||||
@@ -391,9 +392,6 @@ class InversionPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_pix2pix_full(self):
|
||||
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
|
||||
|
||||
# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy"
|
||||
@@ -411,7 +409,7 @@ class InversionPipelineSlowTests(unittest.TestCase):
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe.invert(caption, image=raw_image, generator=generator)
|
||||
output = pipe.invert(caption, image=self.raw_image, generator=generator)
|
||||
inv_latents = output[0]
|
||||
|
||||
source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
|
||||
|
||||
Reference in New Issue
Block a user