From 67e2f95cc4ff8c25f4d04f8bab46df02216527b2 Mon Sep 17 00:00:00 2001 From: Peter Willemsen Date: Wed, 4 Jan 2023 23:07:58 +0100 Subject: [PATCH] New Pipeline: Tiled-upscaling with depth perception to avoid blurry spots (#1615) * added first version of the tiled upscaling pipeline * reformatted to pass code quality tests --- examples/community/tiled_upscaling.py | 298 ++++++++++++++++++++++++++ 1 file changed, 298 insertions(+) create mode 100644 examples/community/tiled_upscaling.py diff --git a/examples/community/tiled_upscaling.py b/examples/community/tiled_upscaling.py new file mode 100644 index 0000000000..e873f2590c --- /dev/null +++ b/examples/community/tiled_upscaling.py @@ -0,0 +1,298 @@ +# Copyright 2022 Peter Willemsen . All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline +from diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler +from PIL import Image +from transformers import CLIPTextModel, CLIPTokenizer + + +def make_transparency_mask(size, overlap_pixels, remove_borders=[]): + size_x = size[0] - overlap_pixels * 2 + size_y = size[1] - overlap_pixels * 2 + for letter in ["l", "r"]: + if letter in remove_borders: + size_x += overlap_pixels + for letter in ["t", "b"]: + if letter in remove_borders: + size_y += overlap_pixels + mask = np.ones((size_y, size_x), dtype=np.uint8) * 255 + mask = np.pad(mask, mode="linear_ramp", pad_width=overlap_pixels, end_values=0) + + if "l" in remove_borders: + mask = mask[:, overlap_pixels : mask.shape[1]] + if "r" in remove_borders: + mask = mask[:, 0 : mask.shape[1] - overlap_pixels] + if "t" in remove_borders: + mask = mask[overlap_pixels : mask.shape[0], :] + if "b" in remove_borders: + mask = mask[0 : mask.shape[0] - overlap_pixels, :] + return mask + + +def clamp(n, smallest, largest): + return max(smallest, min(n, largest)) + + +def clamp_rect(rect: [int], min: [int], max: [int]): + return ( + clamp(rect[0], min[0], max[0]), + clamp(rect[1], min[1], max[1]), + clamp(rect[2], min[0], max[0]), + clamp(rect[3], min[1], max[1]), + ) + + +def add_overlap_rect(rect: [int], overlap: int, image_size: [int]): + rect = list(rect) + rect[0] -= overlap + rect[1] -= overlap + rect[2] += overlap + rect[3] += overlap + rect = clamp_rect(rect, [0, 0], [image_size[0], image_size[1]]) + return rect + + +def squeeze_tile(tile, original_image, original_slice, slice_x): + result = Image.new("RGB", (tile.size[0] + original_slice, tile.size[1])) + result.paste( + original_image.resize((tile.size[0], tile.size[1]), Image.BICUBIC).crop( + (slice_x, 0, slice_x + original_slice, tile.size[1]) + ), + (0, 0), + ) + result.paste(tile, (original_slice, 0)) + return result + + +def unsqueeze_tile(tile, original_image_slice): + crop_rect = (original_image_slice * 4, 0, tile.size[0], tile.size[1]) + tile = tile.crop(crop_rect) + return tile + + +def next_divisible(n, d): + divisor = n % d + return n - divisor + + +class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline): + r""" + Pipeline for tile-based text-guided image super-resolution using Stable Diffusion 2, trading memory for compute + to create gigantic images. + + This model inherits from [`StableDiffusionUpscalePipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + low_res_scheduler ([`SchedulerMixin`]): + A scheduler used to add initial noise to the low res conditioning image. It must be an instance of + [`DDPMScheduler`]. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + low_res_scheduler: DDPMScheduler, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + max_noise_level: int = 350, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + max_noise_level=max_noise_level, + ) + + def _process_tile(self, original_image_slice, x, y, tile_size, tile_border, image, final_image, **kwargs): + torch.manual_seed(0) + crop_rect = ( + min(image.size[0] - (tile_size + original_image_slice), x * tile_size), + min(image.size[1] - (tile_size + original_image_slice), y * tile_size), + min(image.size[0], (x + 1) * tile_size), + min(image.size[1], (y + 1) * tile_size), + ) + crop_rect_with_overlap = add_overlap_rect(crop_rect, tile_border, image.size) + tile = image.crop(crop_rect_with_overlap) + translated_slice_x = ((crop_rect[0] + ((crop_rect[2] - crop_rect[0]) / 2)) / image.size[0]) * tile.size[0] + translated_slice_x = translated_slice_x - (original_image_slice / 2) + translated_slice_x = max(0, translated_slice_x) + to_input = squeeze_tile(tile, image, original_image_slice, translated_slice_x) + orig_input_size = to_input.size + to_input = to_input.resize((tile_size, tile_size), Image.BICUBIC) + upscaled_tile = super(StableDiffusionTiledUpscalePipeline, self).__call__(image=to_input, **kwargs).images[0] + upscaled_tile = upscaled_tile.resize((orig_input_size[0] * 4, orig_input_size[1] * 4), Image.BICUBIC) + upscaled_tile = unsqueeze_tile(upscaled_tile, original_image_slice) + upscaled_tile = upscaled_tile.resize((tile.size[0] * 4, tile.size[1] * 4), Image.BICUBIC) + remove_borders = [] + if x == 0: + remove_borders.append("l") + elif crop_rect[2] == image.size[0]: + remove_borders.append("r") + if y == 0: + remove_borders.append("t") + elif crop_rect[3] == image.size[1]: + remove_borders.append("b") + transparency_mask = Image.fromarray( + make_transparency_mask( + (upscaled_tile.size[0], upscaled_tile.size[1]), tile_border * 4, remove_borders=remove_borders + ), + mode="L", + ) + final_image.paste( + upscaled_tile, (crop_rect_with_overlap[0] * 4, crop_rect_with_overlap[1] * 4), transparency_mask + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[PIL.Image.Image, List[PIL.Image.Image]], + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 50, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + tile_size: int = 128, + tile_border: int = 32, + original_image_slice: int = 32, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + tile_size (`int`, *optional*): + The size of the tiles. Too big can result in an OOM-error. + tile_border (`int`, *optional*): + The number of pixels around a tile to consider (bigger means less seams, too big can lead to an OOM-error). + original_image_slice (`int`, *optional*): + The amount of pixels of the original image to calculate with the current tile (bigger means more depth + is preserved, less blur occurs in the final image, too big can lead to an OOM-error or loss in detail). + callback (`Callable`, *optional*): + A function that take a callback function with a single argument, a dict, + that contains the (partially) processed image under "image", + as well as the progress (0 to 1, where 1 is completed) under "progress". + + Returns: A PIL.Image that is 4 times larger than the original input image. + + """ + + final_image = Image.new("RGB", (image.size[0] * 4, image.size[1] * 4)) + tcx = math.ceil(image.size[0] / tile_size) + tcy = math.ceil(image.size[1] / tile_size) + total_tile_count = tcx * tcy + current_count = 0 + for y in range(tcy): + for x in range(tcx): + self._process_tile( + original_image_slice, + x, + y, + tile_size, + tile_border, + image, + final_image, + prompt=prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + noise_level=noise_level, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + ) + current_count += 1 + if callback is not None: + callback({"progress": current_count / total_tile_count, "image": final_image}) + return final_image + + +def main(): + # Run a demo + model_id = "stabilityai/stable-diffusion-x4-upscaler" + pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe = pipe.to("cuda") + image = Image.open("../../docs/source/imgs/diffusers_library.jpg") + + def callback(obj): + print(f"progress: {obj['progress']:.4f}") + obj["image"].save("diffusers_library_progress.jpg") + + final_image = pipe(image=image, prompt="Black font, white background, vector", noise_level=40, callback=callback) + final_image.save("diffusers_library.jpg") + + +if __name__ == "__main__": + main()