Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ccfaf0b75f | |||
| 7fb481f840 | |||
| 9f5ad1db41 | |||
| 464374fb87 | |||
| d43ce14e2d | |||
| cd0a4a82cf | |||
| 145522cbb7 | |||
| 23bc56a02d | |||
| 5b1dcd1584 | |||
| dbe0094e86 | |||
| f63d32233f | |||
| 5e8e6cb44f | |||
| 3e35f56b00 | |||
| 537891e693 | |||
| 9f28f1abba | |||
| 5d2d23986e | |||
| 1ae9b0595f | |||
| aad69ac2f3 | |||
| ea76880bd7 | |||
| 33f936154d | |||
| e6037e8275 |
@@ -461,12 +461,12 @@ Chain it to an upscaler pipeline to increase the image resolution:
|
||||
from diffusers import StableDiffusionLatentUpscalePipeline
|
||||
|
||||
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
|
||||
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, use_safetensors=True
|
||||
)
|
||||
upscaler.enable_model_cpu_offload()
|
||||
upscaler.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
|
||||
image_2 = upscaler(prompt, image=image_1).images[0]
|
||||
```
|
||||
|
||||
Finally, chain it to a super-resolution pipeline to further enhance the resolution:
|
||||
|
||||
@@ -106,7 +106,7 @@ Let's try it out!
|
||||
|
||||
## Deconstruct the Stable Diffusion pipeline
|
||||
|
||||
Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder to convert the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
|
||||
Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder converts the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
|
||||
|
||||
As you can see, this is already more complex than the DDPM pipeline which only contains a UNet model. The Stable Diffusion model has three separate pretrained models.
|
||||
|
||||
|
||||
+101
-28
@@ -24,8 +24,8 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb) | [Mikail Duzenli](https://github.com/MikailINTech)
|
||||
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb) | [Shyam Sudhakaran](https://github.com/shyamsn97) |
|
||||
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/seed_resizing.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/imagic_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb) | [Juan Carlos Piñeros](https://github.com/juancopi81) |
|
||||
| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | - | [Phạm Hồng Vinh](https://github.com/rootonchair) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
|
||||
@@ -37,7 +37,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
|
||||
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_unclip.ipynb) | [Ray Wang](https://wrong.wang) |
|
||||
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_text_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_image_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
|
||||
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
|
||||
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
@@ -57,7 +57,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
|
||||
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
|
||||
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | - | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |
|
||||
| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/sde_drag.ipynb) | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |
|
||||
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
|
||||
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
|
||||
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
|
||||
@@ -948,10 +948,15 @@ image.save('./imagic/imagic_image_alpha_2.png')
|
||||
Test seed resizing. Originally generate an image in 512 by 512, then generate image with same seed at 512 by 592 using seed resizing. Finally, generate 512 by 592 using original stable diffusion pipeline.
|
||||
|
||||
```python
|
||||
import os
|
||||
import torch as th
|
||||
import numpy as np
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
# Ensure the save directory exists or create it
|
||||
save_dir = './seed_resize/'
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
has_cuda = th.cuda.is_available()
|
||||
device = th.device('cpu' if not has_cuda else 'cuda')
|
||||
|
||||
@@ -965,7 +970,6 @@ def dummy(images, **kwargs):
|
||||
|
||||
pipe.safety_checker = dummy
|
||||
|
||||
|
||||
images = []
|
||||
th.manual_seed(0)
|
||||
generator = th.Generator("cuda").manual_seed(0)
|
||||
@@ -984,15 +988,14 @@ res = pipe(
|
||||
width=width,
|
||||
generator=generator)
|
||||
image = res.images[0]
|
||||
image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
|
||||
|
||||
image.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image.png'.format(w=width, h=height)))
|
||||
|
||||
th.manual_seed(0)
|
||||
generator = th.Generator("cuda").manual_seed(0)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
|
||||
custom_pipeline="seed_resize_stable_diffusion"
|
||||
).to(device)
|
||||
|
||||
width = 512
|
||||
@@ -1006,11 +1009,11 @@ res = pipe(
|
||||
width=width,
|
||||
generator=generator)
|
||||
image = res.images[0]
|
||||
image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
|
||||
image.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image.png'.format(w=width, h=height)))
|
||||
|
||||
pipe_compare = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
|
||||
custom_pipeline="seed_resize_stable_diffusion"
|
||||
).to(device)
|
||||
|
||||
res = pipe_compare(
|
||||
@@ -1023,7 +1026,7 @@ res = pipe_compare(
|
||||
)
|
||||
|
||||
image = res.images[0]
|
||||
image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height))
|
||||
image.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height)))
|
||||
```
|
||||
|
||||
### Multilingual Stable Diffusion Pipeline
|
||||
@@ -1543,6 +1546,8 @@ This Diffusion Pipeline takes two images or an image_embeddings tensor of size 2
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from PIL import Image
|
||||
import requests
|
||||
from io import BytesIO
|
||||
|
||||
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
|
||||
dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
|
||||
@@ -1554,13 +1559,25 @@ pipe = DiffusionPipeline.from_pretrained(
|
||||
)
|
||||
pipe.to(device)
|
||||
|
||||
images = [Image.open('./starry_night.jpg'), Image.open('./flowers.jpg')]
|
||||
# List of image URLs
|
||||
image_urls = [
|
||||
'https://camo.githubusercontent.com/ef13c8059b12947c0d5e8d3ea88900de6bf1cd76bbf61ace3928e824c491290e/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f4e616761536169416268696e61792f556e434c4950496d616765496e746572706f6c6174696f6e53616d706c65732f7265736f6c76652f6d61696e2f7374617272795f6e696768742e6a7067',
|
||||
'https://camo.githubusercontent.com/d1947ab7c49ae3f550c28409d5e8b120df48e456559cf4557306c0848337702c/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f4e616761536169416268696e61792f556e434c4950496d616765496e746572706f6c6174696f6e53616d706c65732f7265736f6c76652f6d61696e2f666c6f776572732e6a7067'
|
||||
]
|
||||
|
||||
# Open images from URLs
|
||||
images = []
|
||||
for url in image_urls:
|
||||
response = requests.get(url)
|
||||
img = Image.open(BytesIO(response.content))
|
||||
images.append(img)
|
||||
|
||||
# For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths.
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
|
||||
output = pipe(image=images, steps=6, generator=generator)
|
||||
|
||||
for i,image in enumerate(output.images):
|
||||
for i, image in enumerate(output.images):
|
||||
image.save('starry_to_flowers_%s.jpg' % i)
|
||||
```
|
||||
|
||||
@@ -3909,33 +3926,89 @@ This pipeline provides drag-and-drop image editing using stochastic differential
|
||||
See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information.
|
||||
|
||||
```py
|
||||
import PIL
|
||||
import torch
|
||||
from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
from PIL import Image
|
||||
import requests
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
|
||||
# Load the pipeline
|
||||
model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
|
||||
pipe.to('cuda')
|
||||
|
||||
# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
|
||||
# If not training LoRA, please avoid using torch.float16
|
||||
# pipe.to(torch.float16)
|
||||
# Ensure the model is moved to the GPU
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
pipe.to(device)
|
||||
|
||||
# Provide prompt, image, mask image, and the starting and target points for drag editing.
|
||||
prompt = "prompt of the image"
|
||||
image = PIL.Image.open('/path/to/image')
|
||||
mask_image = PIL.Image.open('/path/to/mask_image')
|
||||
source_points = [[123, 456]]
|
||||
target_points = [[234, 567]]
|
||||
# Function to load image from URL
|
||||
def load_image_from_url(url):
|
||||
response = requests.get(url)
|
||||
return Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
# train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.
|
||||
pipe.train_lora(prompt, image)
|
||||
# Function to prepare mask
|
||||
def prepare_mask(mask_image):
|
||||
# Convert to grayscale
|
||||
mask = mask_image.convert("L")
|
||||
return mask
|
||||
|
||||
output = pipe(prompt, image, mask_image, source_points, target_points)
|
||||
output_image = PIL.Image.fromarray(output)
|
||||
# Function to convert numpy array to PIL Image
|
||||
def array_to_pil(array):
|
||||
# Ensure the array is in uint8 format
|
||||
if array.dtype != np.uint8:
|
||||
if array.max() <= 1.0:
|
||||
array = (array * 255).astype(np.uint8)
|
||||
else:
|
||||
array = array.astype(np.uint8)
|
||||
|
||||
# Handle different array shapes
|
||||
if len(array.shape) == 3:
|
||||
if array.shape[0] == 3: # If channels first
|
||||
array = array.transpose(1, 2, 0)
|
||||
return Image.fromarray(array)
|
||||
elif len(array.shape) == 4: # If batch dimension
|
||||
array = array[0]
|
||||
if array.shape[0] == 3: # If channels first
|
||||
array = array.transpose(1, 2, 0)
|
||||
return Image.fromarray(array)
|
||||
else:
|
||||
raise ValueError(f"Unexpected array shape: {array.shape}")
|
||||
|
||||
# Image and mask URLs
|
||||
image_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png'
|
||||
mask_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png'
|
||||
|
||||
# Load the images
|
||||
image = load_image_from_url(image_url)
|
||||
mask_image = load_image_from_url(mask_url)
|
||||
|
||||
# Resize images to a size that's compatible with the model's latent space
|
||||
image = image.resize((512, 512))
|
||||
mask_image = mask_image.resize((512, 512))
|
||||
|
||||
# Prepare the mask (keep as PIL Image)
|
||||
mask = prepare_mask(mask_image)
|
||||
|
||||
# Provide the prompt and points for drag editing
|
||||
prompt = "A cute dog"
|
||||
source_points = [[32, 32]] # Adjusted for 512x512 image
|
||||
target_points = [[64, 64]] # Adjusted for 512x512 image
|
||||
|
||||
# Generate the output image
|
||||
output_array = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
mask_image=mask,
|
||||
source_points=source_points,
|
||||
target_points=target_points
|
||||
)
|
||||
|
||||
# Convert output array to PIL Image and save
|
||||
output_image = array_to_pil(output_array)
|
||||
output_image.save("./output.png")
|
||||
print("Output image saved as './output.png'")
|
||||
|
||||
```
|
||||
|
||||
### Instaflow Pipeline
|
||||
|
||||
@@ -995,7 +995,8 @@ def main(args):
|
||||
if args.enable_npu_flash_attention:
|
||||
if is_torch_npu_available():
|
||||
logger.info("npu flash attention enabled.")
|
||||
transformer.enable_npu_flash_attention()
|
||||
for block in transformer.transformer_blocks:
|
||||
block.attn2.set_use_npu_flash_attention(True)
|
||||
else:
|
||||
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
|
||||
|
||||
|
||||
@@ -695,7 +695,7 @@ def main():
|
||||
)
|
||||
# We need to ensure that the original and the edited images undergo the same
|
||||
# augmentation transforms.
|
||||
images = np.concatenate([original_images, edited_images])
|
||||
images = np.stack([original_images, edited_images])
|
||||
images = torch.tensor(images)
|
||||
images = 2 * (images / 255) - 1
|
||||
return train_transforms(images)
|
||||
@@ -706,7 +706,7 @@ def main():
|
||||
# Since the original and edited images were concatenated before
|
||||
# applying the transformations, we need to separate them and reshape
|
||||
# them accordingly.
|
||||
original_images, edited_images = preprocessed_images.chunk(2)
|
||||
original_images, edited_images = preprocessed_images
|
||||
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
|
||||
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
|
||||
|
||||
|
||||
@@ -766,7 +766,7 @@ def main():
|
||||
)
|
||||
# We need to ensure that the original and the edited images undergo the same
|
||||
# augmentation transforms.
|
||||
images = np.concatenate([original_images, edited_images])
|
||||
images = np.stack([original_images, edited_images])
|
||||
images = torch.tensor(images)
|
||||
images = 2 * (images / 255) - 1
|
||||
return train_transforms(images)
|
||||
@@ -906,7 +906,7 @@ def main():
|
||||
# Since the original and edited images were concatenated before
|
||||
# applying the transformations, we need to separate them and reshape
|
||||
# them accordingly.
|
||||
original_images, edited_images = preprocessed_images.chunk(2)
|
||||
original_images, edited_images = preprocessed_images
|
||||
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
|
||||
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
|
||||
|
||||
|
||||
@@ -82,31 +82,11 @@ pipeline = EasyPipelineForInpainting.from_huggingface(
|
||||
## Search Civitai and Huggingface
|
||||
|
||||
```python
|
||||
from pipeline_easy import (
|
||||
search_huggingface,
|
||||
search_civitai,
|
||||
)
|
||||
|
||||
# Search Lora
|
||||
Lora = search_civitai(
|
||||
"Keyword_to_search_Lora",
|
||||
model_type="LORA",
|
||||
base_model = "SD 1.5",
|
||||
download=True,
|
||||
)
|
||||
# Load Lora into the pipeline.
|
||||
pipeline.load_lora_weights(Lora)
|
||||
pipeline.auto_load_lora_weights("Detail Tweaker")
|
||||
|
||||
|
||||
# Search TextualInversion
|
||||
TextualInversion = search_civitai(
|
||||
"EasyNegative",
|
||||
model_type="TextualInversion",
|
||||
base_model = "SD 1.5",
|
||||
download=True
|
||||
)
|
||||
# Load TextualInversion into the pipeline.
|
||||
pipeline.load_textual_inversion(TextualInversion, token="EasyNegative")
|
||||
pipeline.auto_load_textual_inversion("EasyNegative", token="EasyNegative")
|
||||
```
|
||||
|
||||
### Search Civitai
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 suzukimain
|
||||
# Copyright 2025 suzukimain
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -15,11 +15,13 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
import types
|
||||
from collections import OrderedDict
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Union
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_api, hf_hub_download
|
||||
from huggingface_hub.file_download import http_get
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
@@ -30,6 +32,7 @@ from diffusers.loaders.single_file_utils import (
|
||||
infer_diffusers_model_type,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
from diffusers.pipelines.animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline
|
||||
from diffusers.pipelines.auto_pipeline import (
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
@@ -39,13 +42,18 @@ from diffusers.pipelines.controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from diffusers.pipelines.flux import FluxImg2ImgPipeline, FluxPipeline
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
|
||||
from diffusers.pipelines.stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
@@ -59,46 +67,133 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING = OrderedDict(
|
||||
[
|
||||
("xl_base", StableDiffusionXLPipeline),
|
||||
("xl_refiner", StableDiffusionXLPipeline),
|
||||
("xl_inpaint", None),
|
||||
("playground-v2-5", StableDiffusionXLPipeline),
|
||||
("upscale", None),
|
||||
("animatediff_rgb", AnimateDiffPipeline),
|
||||
("animatediff_scribble", AnimateDiffPipeline),
|
||||
("animatediff_sdxl_beta", AnimateDiffSDXLPipeline),
|
||||
("animatediff_v1", AnimateDiffPipeline),
|
||||
("animatediff_v2", AnimateDiffPipeline),
|
||||
("animatediff_v3", AnimateDiffPipeline),
|
||||
("autoencoder-dc-f128c512", None),
|
||||
("autoencoder-dc-f32c32", None),
|
||||
("autoencoder-dc-f32c32-sana", None),
|
||||
("autoencoder-dc-f64c128", None),
|
||||
("controlnet", StableDiffusionControlNetPipeline),
|
||||
("controlnet_xl", StableDiffusionXLControlNetPipeline),
|
||||
("controlnet_xl_large", StableDiffusionXLControlNetPipeline),
|
||||
("controlnet_xl_mid", StableDiffusionXLControlNetPipeline),
|
||||
("controlnet_xl_small", StableDiffusionXLControlNetPipeline),
|
||||
("flux-depth", FluxPipeline),
|
||||
("flux-dev", FluxPipeline),
|
||||
("flux-fill", FluxPipeline),
|
||||
("flux-schnell", FluxPipeline),
|
||||
("hunyuan-video", None),
|
||||
("inpainting", None),
|
||||
("inpainting_v2", None),
|
||||
("controlnet", StableDiffusionControlNetPipeline),
|
||||
("v2", StableDiffusionPipeline),
|
||||
("ltx-video", None),
|
||||
("ltx-video-0.9.1", None),
|
||||
("mochi-1-preview", None),
|
||||
("playground-v2-5", StableDiffusionXLPipeline),
|
||||
("sd3", StableDiffusion3Pipeline),
|
||||
("sd35_large", StableDiffusion3Pipeline),
|
||||
("sd35_medium", StableDiffusion3Pipeline),
|
||||
("stable_cascade_stage_b", None),
|
||||
("stable_cascade_stage_b_lite", None),
|
||||
("stable_cascade_stage_c", None),
|
||||
("stable_cascade_stage_c_lite", None),
|
||||
("upscale", StableDiffusionUpscalePipeline),
|
||||
("v1", StableDiffusionPipeline),
|
||||
("v2", StableDiffusionPipeline),
|
||||
("xl_base", StableDiffusionXLPipeline),
|
||||
("xl_inpaint", None),
|
||||
("xl_refiner", StableDiffusionXLPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING = OrderedDict(
|
||||
[
|
||||
("xl_base", StableDiffusionXLImg2ImgPipeline),
|
||||
("xl_refiner", StableDiffusionXLImg2ImgPipeline),
|
||||
("xl_inpaint", None),
|
||||
("playground-v2-5", StableDiffusionXLImg2ImgPipeline),
|
||||
("upscale", None),
|
||||
("animatediff_rgb", AnimateDiffPipeline),
|
||||
("animatediff_scribble", AnimateDiffPipeline),
|
||||
("animatediff_sdxl_beta", AnimateDiffSDXLPipeline),
|
||||
("animatediff_v1", AnimateDiffPipeline),
|
||||
("animatediff_v2", AnimateDiffPipeline),
|
||||
("animatediff_v3", AnimateDiffPipeline),
|
||||
("autoencoder-dc-f128c512", None),
|
||||
("autoencoder-dc-f32c32", None),
|
||||
("autoencoder-dc-f32c32-sana", None),
|
||||
("autoencoder-dc-f64c128", None),
|
||||
("controlnet", StableDiffusionControlNetImg2ImgPipeline),
|
||||
("controlnet_xl", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
("controlnet_xl_large", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
("controlnet_xl_mid", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
("controlnet_xl_small", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
("flux-depth", FluxImg2ImgPipeline),
|
||||
("flux-dev", FluxImg2ImgPipeline),
|
||||
("flux-fill", FluxImg2ImgPipeline),
|
||||
("flux-schnell", FluxImg2ImgPipeline),
|
||||
("hunyuan-video", None),
|
||||
("inpainting", None),
|
||||
("inpainting_v2", None),
|
||||
("controlnet", StableDiffusionControlNetImg2ImgPipeline),
|
||||
("v2", StableDiffusionImg2ImgPipeline),
|
||||
("ltx-video", None),
|
||||
("ltx-video-0.9.1", None),
|
||||
("mochi-1-preview", None),
|
||||
("playground-v2-5", StableDiffusionXLImg2ImgPipeline),
|
||||
("sd3", StableDiffusion3Img2ImgPipeline),
|
||||
("sd35_large", StableDiffusion3Img2ImgPipeline),
|
||||
("sd35_medium", StableDiffusion3Img2ImgPipeline),
|
||||
("stable_cascade_stage_b", None),
|
||||
("stable_cascade_stage_b_lite", None),
|
||||
("stable_cascade_stage_c", None),
|
||||
("stable_cascade_stage_c_lite", None),
|
||||
("upscale", StableDiffusionUpscalePipeline),
|
||||
("v1", StableDiffusionImg2ImgPipeline),
|
||||
("v2", StableDiffusionImg2ImgPipeline),
|
||||
("xl_base", StableDiffusionXLImg2ImgPipeline),
|
||||
("xl_inpaint", None),
|
||||
("xl_refiner", StableDiffusionXLImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING = OrderedDict(
|
||||
[
|
||||
("xl_base", None),
|
||||
("xl_refiner", None),
|
||||
("xl_inpaint", StableDiffusionXLInpaintPipeline),
|
||||
("playground-v2-5", None),
|
||||
("upscale", None),
|
||||
("animatediff_rgb", None),
|
||||
("animatediff_scribble", None),
|
||||
("animatediff_sdxl_beta", None),
|
||||
("animatediff_v1", None),
|
||||
("animatediff_v2", None),
|
||||
("animatediff_v3", None),
|
||||
("autoencoder-dc-f128c512", None),
|
||||
("autoencoder-dc-f32c32", None),
|
||||
("autoencoder-dc-f32c32-sana", None),
|
||||
("autoencoder-dc-f64c128", None),
|
||||
("controlnet", StableDiffusionControlNetInpaintPipeline),
|
||||
("controlnet_xl", None),
|
||||
("controlnet_xl_large", None),
|
||||
("controlnet_xl_mid", None),
|
||||
("controlnet_xl_small", None),
|
||||
("flux-depth", None),
|
||||
("flux-dev", None),
|
||||
("flux-fill", None),
|
||||
("flux-schnell", None),
|
||||
("hunyuan-video", None),
|
||||
("inpainting", StableDiffusionInpaintPipeline),
|
||||
("inpainting_v2", StableDiffusionInpaintPipeline),
|
||||
("controlnet", StableDiffusionControlNetInpaintPipeline),
|
||||
("v2", None),
|
||||
("ltx-video", None),
|
||||
("ltx-video-0.9.1", None),
|
||||
("mochi-1-preview", None),
|
||||
("playground-v2-5", None),
|
||||
("sd3", None),
|
||||
("sd35_large", None),
|
||||
("sd35_medium", None),
|
||||
("stable_cascade_stage_b", None),
|
||||
("stable_cascade_stage_b_lite", None),
|
||||
("stable_cascade_stage_c", None),
|
||||
("stable_cascade_stage_c_lite", None),
|
||||
("upscale", StableDiffusionUpscalePipeline),
|
||||
("v1", None),
|
||||
("v2", None),
|
||||
("xl_base", None),
|
||||
("xl_inpaint", StableDiffusionXLInpaintPipeline),
|
||||
("xl_refiner", None),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -116,14 +211,33 @@ CONFIG_FILE_LIST = [
|
||||
"diffusion_pytorch_model.non_ema.safetensors",
|
||||
]
|
||||
|
||||
DIFFUSERS_CONFIG_DIR = ["safety_checker", "unet", "vae", "text_encoder", "text_encoder_2"]
|
||||
|
||||
INPAINT_PIPELINE_KEYS = [
|
||||
"xl_inpaint",
|
||||
"inpainting",
|
||||
"inpainting_v2",
|
||||
DIFFUSERS_CONFIG_DIR = [
|
||||
"safety_checker",
|
||||
"unet",
|
||||
"vae",
|
||||
"text_encoder",
|
||||
"text_encoder_2",
|
||||
]
|
||||
|
||||
TOKENIZER_SHAPE_MAP = {
|
||||
768: [
|
||||
"SD 1.4",
|
||||
"SD 1.5",
|
||||
"SD 1.5 LCM",
|
||||
"SDXL 0.9",
|
||||
"SDXL 1.0",
|
||||
"SDXL 1.0 LCM",
|
||||
"SDXL Distilled",
|
||||
"SDXL Turbo",
|
||||
"SDXL Lightning",
|
||||
"PixArt a",
|
||||
"Playground v2",
|
||||
"Pony",
|
||||
],
|
||||
1024: ["SD 2.0", "SD 2.0 768", "SD 2.1", "SD 2.1 768", "SD 2.1 Unclip"],
|
||||
}
|
||||
|
||||
|
||||
EXTENSION = [".safetensors", ".ckpt", ".bin"]
|
||||
|
||||
CACHE_HOME = os.path.expanduser("~/.cache")
|
||||
@@ -162,12 +276,28 @@ class ModelStatus:
|
||||
The name of the model file.
|
||||
local (`bool`):
|
||||
Whether the model exists locally
|
||||
site_url (`str`):
|
||||
The URL of the site where the model is hosted.
|
||||
"""
|
||||
|
||||
search_word: str = ""
|
||||
download_url: str = ""
|
||||
file_name: str = ""
|
||||
local: bool = False
|
||||
site_url: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraStatus:
|
||||
r"""
|
||||
Data class for storing extra status information.
|
||||
|
||||
Attributes:
|
||||
trained_words (`str`):
|
||||
The words used to trigger the model
|
||||
"""
|
||||
|
||||
trained_words: Union[List[str], None] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -191,8 +321,9 @@ class SearchResult:
|
||||
model_path: str = ""
|
||||
loading_method: Union[str, None] = None
|
||||
checkpoint_format: Union[str, None] = None
|
||||
repo_status: RepoStatus = RepoStatus()
|
||||
model_status: ModelStatus = ModelStatus()
|
||||
repo_status: RepoStatus = field(default_factory=RepoStatus)
|
||||
model_status: ModelStatus = field(default_factory=ModelStatus)
|
||||
extra_status: ExtraStatus = field(default_factory=ExtraStatus)
|
||||
|
||||
|
||||
@validate_hf_hub_args
|
||||
@@ -385,6 +516,7 @@ def file_downloader(
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
displayed_filename = kwargs.pop("displayed_filename", None)
|
||||
|
||||
# Default mode for file writing and initial file size
|
||||
mode = "wb"
|
||||
file_size = 0
|
||||
@@ -396,7 +528,7 @@ def file_downloader(
|
||||
if os.path.exists(save_path):
|
||||
if not force_download:
|
||||
# If the file exists and force_download is False, skip the download
|
||||
logger.warning(f"File already exists: {save_path}, skipping download.")
|
||||
logger.info(f"File already exists: {save_path}, skipping download.")
|
||||
return None
|
||||
elif resume:
|
||||
# If resuming, set mode to append binary and get current file size
|
||||
@@ -457,10 +589,18 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
|
||||
gated = kwargs.pop("gated", False)
|
||||
skip_error = kwargs.pop("skip_error", False)
|
||||
|
||||
file_list = []
|
||||
hf_repo_info = {}
|
||||
hf_security_info = {}
|
||||
model_path = ""
|
||||
repo_id, file_name = "", ""
|
||||
diffusers_model_exists = False
|
||||
|
||||
# Get the type and loading method for the keyword
|
||||
search_word_status = get_keyword_types(search_word)
|
||||
|
||||
if search_word_status["type"]["hf_repo"]:
|
||||
hf_repo_info = hf_api.model_info(repo_id=search_word, securityStatus=True)
|
||||
if download:
|
||||
model_path = DiffusionPipeline.download(
|
||||
search_word,
|
||||
@@ -503,13 +643,6 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
|
||||
)
|
||||
model_dicts = [asdict(value) for value in list(hf_models)]
|
||||
|
||||
file_list = []
|
||||
hf_repo_info = {}
|
||||
hf_security_info = {}
|
||||
model_path = ""
|
||||
repo_id, file_name = "", ""
|
||||
diffusers_model_exists = False
|
||||
|
||||
# Loop through models to find a suitable candidate
|
||||
for repo_info in model_dicts:
|
||||
repo_id = repo_info["id"]
|
||||
@@ -523,7 +656,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
|
||||
if hf_security_info["scansDone"]:
|
||||
for info in repo_info["siblings"]:
|
||||
file_path = info["rfilename"]
|
||||
if "model_index.json" == file_path and checkpoint_format in ["diffusers", "all"]:
|
||||
if "model_index.json" == file_path and checkpoint_format in [
|
||||
"diffusers",
|
||||
"all",
|
||||
]:
|
||||
diffusers_model_exists = True
|
||||
break
|
||||
|
||||
@@ -571,6 +707,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
# `pathlib.PosixPath` may be returned
|
||||
if model_path:
|
||||
model_path = str(model_path)
|
||||
|
||||
if file_name:
|
||||
download_url = f"https://huggingface.co/{repo_id}/blob/main/{file_name}"
|
||||
else:
|
||||
@@ -586,10 +726,12 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
|
||||
repo_status=RepoStatus(repo_id=repo_id, repo_hash=hf_repo_info.sha, version=revision),
|
||||
model_status=ModelStatus(
|
||||
search_word=search_word,
|
||||
site_url=download_url,
|
||||
download_url=download_url,
|
||||
file_name=file_name,
|
||||
local=download,
|
||||
),
|
||||
extra_status=ExtraStatus(trained_words=None),
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -605,6 +747,8 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
|
||||
The search query string.
|
||||
model_type (`str`, *optional*, defaults to `Checkpoint`):
|
||||
The type of model to search for.
|
||||
sort (`str`, *optional*):
|
||||
The order in which you wish to sort the results(for example, `Highest Rated`, `Most Downloaded`, `Newest`).
|
||||
base_model (`str`, *optional*):
|
||||
The base model to filter by.
|
||||
download (`bool`, *optional*, defaults to `False`):
|
||||
@@ -628,6 +772,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
|
||||
|
||||
# Extract additional parameters from kwargs
|
||||
model_type = kwargs.pop("model_type", "Checkpoint")
|
||||
sort = kwargs.pop("sort", None)
|
||||
download = kwargs.pop("download", False)
|
||||
base_model = kwargs.pop("base_model", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
@@ -642,6 +787,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
|
||||
repo_name = ""
|
||||
repo_id = ""
|
||||
version_id = ""
|
||||
trainedWords = ""
|
||||
models_list = []
|
||||
selected_repo = {}
|
||||
selected_model = {}
|
||||
@@ -652,12 +798,16 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
|
||||
params = {
|
||||
"query": search_word,
|
||||
"types": model_type,
|
||||
"sort": "Most Downloaded",
|
||||
"limit": 20,
|
||||
}
|
||||
if base_model is not None:
|
||||
if not isinstance(base_model, list):
|
||||
base_model = [base_model]
|
||||
params["baseModel"] = base_model
|
||||
|
||||
if sort is not None:
|
||||
params["sort"] = sort
|
||||
|
||||
headers = {}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
@@ -686,25 +836,30 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
|
||||
|
||||
# Sort versions within the selected repo by download count
|
||||
sorted_versions = sorted(
|
||||
selected_repo["modelVersions"], key=lambda x: x["stats"]["downloadCount"], reverse=True
|
||||
selected_repo["modelVersions"],
|
||||
key=lambda x: x["stats"]["downloadCount"],
|
||||
reverse=True,
|
||||
)
|
||||
for selected_version in sorted_versions:
|
||||
version_id = selected_version["id"]
|
||||
trainedWords = selected_version["trainedWords"]
|
||||
models_list = []
|
||||
for model_data in selected_version["files"]:
|
||||
# Check if the file passes security scans and has a valid extension
|
||||
file_name = model_data["name"]
|
||||
if (
|
||||
model_data["pickleScanResult"] == "Success"
|
||||
and model_data["virusScanResult"] == "Success"
|
||||
and any(file_name.endswith(ext) for ext in EXTENSION)
|
||||
and os.path.basename(os.path.dirname(file_name)) not in DIFFUSERS_CONFIG_DIR
|
||||
):
|
||||
file_status = {
|
||||
"filename": file_name,
|
||||
"download_url": model_data["downloadUrl"],
|
||||
}
|
||||
models_list.append(file_status)
|
||||
# When searching for textual inversion, results other than the values entered for the base model may come up, so check again.
|
||||
if base_model is None or selected_version["baseModel"] in base_model:
|
||||
for model_data in selected_version["files"]:
|
||||
# Check if the file passes security scans and has a valid extension
|
||||
file_name = model_data["name"]
|
||||
if (
|
||||
model_data["pickleScanResult"] == "Success"
|
||||
and model_data["virusScanResult"] == "Success"
|
||||
and any(file_name.endswith(ext) for ext in EXTENSION)
|
||||
and os.path.basename(os.path.dirname(file_name)) not in DIFFUSERS_CONFIG_DIR
|
||||
):
|
||||
file_status = {
|
||||
"filename": file_name,
|
||||
"download_url": model_data["downloadUrl"],
|
||||
}
|
||||
models_list.append(file_status)
|
||||
|
||||
if models_list:
|
||||
# Sort the models list by filename and find the safest model
|
||||
@@ -764,19 +919,229 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
|
||||
repo_status=RepoStatus(repo_id=repo_name, repo_hash=repo_id, version=version_id),
|
||||
model_status=ModelStatus(
|
||||
search_word=search_word,
|
||||
site_url=f"https://civitai.com/models/{repo_id}?modelVersionId={version_id}",
|
||||
download_url=download_url,
|
||||
file_name=file_name,
|
||||
local=output_info["type"]["local"],
|
||||
),
|
||||
extra_status=ExtraStatus(trained_words=trainedWords or None),
|
||||
)
|
||||
|
||||
|
||||
def add_methods(pipeline):
|
||||
r"""
|
||||
Add methods from `AutoConfig` to the pipeline.
|
||||
|
||||
Parameters:
|
||||
pipeline (`Pipeline`):
|
||||
The pipeline to which the methods will be added.
|
||||
"""
|
||||
for attr_name in dir(AutoConfig):
|
||||
attr_value = getattr(AutoConfig, attr_name)
|
||||
if callable(attr_value) and not attr_name.startswith("__"):
|
||||
setattr(pipeline, attr_name, types.MethodType(attr_value, pipeline))
|
||||
return pipeline
|
||||
|
||||
|
||||
class AutoConfig:
|
||||
def auto_load_textual_inversion(
|
||||
self,
|
||||
pretrained_model_name_or_path: Union[str, List[str]],
|
||||
token: Optional[Union[str, List[str]]] = None,
|
||||
base_model: Optional[Union[str, List[str]]] = None,
|
||||
tokenizer=None,
|
||||
text_encoder=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
|
||||
Automatic1111 formats are supported).
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
|
||||
Can be either one of the following or a list of them:
|
||||
|
||||
- Search keywords for pretrained model (for example `EasyNegative`).
|
||||
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
|
||||
pretrained model hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
|
||||
inversion weights.
|
||||
- A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
token (`str` or `List[str]`, *optional*):
|
||||
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
|
||||
list, then `token` must also be a list of equal length.
|
||||
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
If not specified, function will take self.tokenizer.
|
||||
tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
|
||||
A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
|
||||
weight_name (`str`, *optional*):
|
||||
Name of a custom weight file. This should be used when:
|
||||
|
||||
- The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
|
||||
name such as `text_inv.bin`.
|
||||
- The saved textual inversion file is in the Automatic1111 format.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from auto_diffusers import EasyPipelineForText2Image
|
||||
|
||||
>>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
|
||||
>>> pipeline.auto_load_textual_inversion("EasyNegative", token="EasyNegative")
|
||||
|
||||
>>> image = pipeline(prompt).images[0]
|
||||
```
|
||||
|
||||
"""
|
||||
# 1. Set tokenizer and text encoder
|
||||
tokenizer = tokenizer or getattr(self, "tokenizer", None)
|
||||
text_encoder = text_encoder or getattr(self, "text_encoder", None)
|
||||
|
||||
# Check if tokenizer and text encoder are provided
|
||||
if tokenizer is None or text_encoder is None:
|
||||
raise ValueError("Tokenizer and text encoder must be provided.")
|
||||
|
||||
# 2. Normalize inputs
|
||||
pretrained_model_name_or_paths = (
|
||||
[pretrained_model_name_or_path]
|
||||
if not isinstance(pretrained_model_name_or_path, list)
|
||||
else pretrained_model_name_or_path
|
||||
)
|
||||
|
||||
# 2.1 Normalize tokens
|
||||
tokens = [token] if not isinstance(token, list) else token
|
||||
if tokens[0] is None:
|
||||
tokens = tokens * len(pretrained_model_name_or_paths)
|
||||
|
||||
for check_token in tokens:
|
||||
# Check if token is already in tokenizer vocabulary
|
||||
if check_token in tokenizer.get_vocab():
|
||||
raise ValueError(
|
||||
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
|
||||
)
|
||||
|
||||
expected_shape = text_encoder.get_input_embeddings().weight.shape[-1] # Expected shape of tokenizer
|
||||
|
||||
for search_word in pretrained_model_name_or_paths:
|
||||
if isinstance(search_word, str):
|
||||
# Update kwargs to ensure the model is downloaded and parameters are included
|
||||
_status = {
|
||||
"download": True,
|
||||
"include_params": True,
|
||||
"skip_error": False,
|
||||
"model_type": "TextualInversion",
|
||||
}
|
||||
# Get tags for the base model of textual inversion compatible with tokenizer.
|
||||
# If the tokenizer is 768-dimensional, set tags for SD 1.x and SDXL.
|
||||
# If the tokenizer is 1024-dimensional, set tags for SD 2.x.
|
||||
if expected_shape in TOKENIZER_SHAPE_MAP:
|
||||
# Retrieve the appropriate tags from the TOKENIZER_SHAPE_MAP based on the expected shape
|
||||
tags = TOKENIZER_SHAPE_MAP[expected_shape]
|
||||
if base_model is not None:
|
||||
if isinstance(base_model, list):
|
||||
tags.extend(base_model)
|
||||
else:
|
||||
tags.append(base_model)
|
||||
_status["base_model"] = tags
|
||||
|
||||
kwargs.update(_status)
|
||||
# Search for the model on Civitai and get the model status
|
||||
textual_inversion_path = search_civitai(search_word, **kwargs)
|
||||
logger.warning(
|
||||
f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
|
||||
)
|
||||
|
||||
pretrained_model_name_or_paths[
|
||||
pretrained_model_name_or_paths.index(search_word)
|
||||
] = textual_inversion_path.model_path
|
||||
|
||||
self.load_textual_inversion(
|
||||
pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
|
||||
)
|
||||
|
||||
def auto_load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
r"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
||||
`self.text_encoder`.
|
||||
|
||||
All kwargs are forwarded to `self.lora_state_dict`.
|
||||
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
|
||||
loaded.
|
||||
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
|
||||
loaded into `self.unet`.
|
||||
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
|
||||
dict is loaded into `self.text_encoder`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, str):
|
||||
# Update kwargs to ensure the model is downloaded and parameters are included
|
||||
_status = {
|
||||
"download": True,
|
||||
"include_params": True,
|
||||
"skip_error": False,
|
||||
"model_type": "LORA",
|
||||
}
|
||||
kwargs.update(_status)
|
||||
# Search for the model on Civitai and get the model status
|
||||
lora_path = search_civitai(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
logger.warning(f"lora_path: {lora_path.model_status.site_url}")
|
||||
logger.warning(f"trained_words: {lora_path.extra_status.trained_words}")
|
||||
pretrained_model_name_or_path_or_dict = lora_path.model_path
|
||||
|
||||
self.load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs)
|
||||
|
||||
|
||||
class EasyPipelineForText2Image(AutoPipelineForText2Image):
|
||||
r"""
|
||||
|
||||
[`AutoPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The
|
||||
[`EasyPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The
|
||||
specific underlying pipeline class is automatically selected from either the
|
||||
[`~AutoPipelineForText2Image.from_pretrained`] or [`~AutoPipelineForText2Image.from_pipe`] methods.
|
||||
[`~EasyPipelineForText2Image.from_pretrained`], [`~EasyPipelineForText2Image.from_pipe`], [`~EasyPipelineForText2Image.from_huggingface`] or [`~EasyPipelineForText2Image.from_civitai`] methods.
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
|
||||
@@ -891,9 +1256,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForText2Image
|
||||
>>> from auto_diffusers import EasyPipelineForText2Image
|
||||
|
||||
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> image = pipeline(prompt).images[0]
|
||||
```
|
||||
"""
|
||||
@@ -907,20 +1272,21 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
|
||||
kwargs.update(_status)
|
||||
|
||||
# Search for the model on Hugging Face and get the model status
|
||||
hf_model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {hf_model_status.model_status.download_url}")
|
||||
checkpoint_path = hf_model_status.model_path
|
||||
hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
|
||||
checkpoint_path = hf_checkpoint_status.model_path
|
||||
|
||||
# Check the format of the model checkpoint
|
||||
if hf_model_status.checkpoint_format == "single_file":
|
||||
if hf_checkpoint_status.loading_method == "from_single_file":
|
||||
# Load the pipeline from a single file checkpoint
|
||||
return load_pipeline_from_single_file(
|
||||
pipeline = load_pipeline_from_single_file(
|
||||
pretrained_model_or_path=checkpoint_path,
|
||||
pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return cls.from_pretrained(checkpoint_path, **kwargs)
|
||||
pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
|
||||
return add_methods(pipeline)
|
||||
|
||||
@classmethod
|
||||
def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
|
||||
@@ -999,9 +1365,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForText2Image
|
||||
>>> from auto_diffusers import EasyPipelineForText2Image
|
||||
|
||||
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> image = pipeline(prompt).images[0]
|
||||
```
|
||||
"""
|
||||
@@ -1015,24 +1381,25 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
|
||||
kwargs.update(_status)
|
||||
|
||||
# Search for the model on Civitai and get the model status
|
||||
model_status = search_civitai(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
|
||||
checkpoint_path = model_status.model_path
|
||||
checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
|
||||
checkpoint_path = checkpoint_status.model_path
|
||||
|
||||
# Load the pipeline from a single file checkpoint
|
||||
return load_pipeline_from_single_file(
|
||||
pipeline = load_pipeline_from_single_file(
|
||||
pretrained_model_or_path=checkpoint_path,
|
||||
pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
|
||||
**kwargs,
|
||||
)
|
||||
return add_methods(pipeline)
|
||||
|
||||
|
||||
class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
|
||||
r"""
|
||||
|
||||
[`AutoPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The
|
||||
[`EasyPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The
|
||||
specific underlying pipeline class is automatically selected from either the
|
||||
[`~AutoPipelineForImage2Image.from_pretrained`] or [`~AutoPipelineForImage2Image.from_pipe`] methods.
|
||||
[`~EasyPipelineForImage2Image.from_pretrained`], [`~EasyPipelineForImage2Image.from_pipe`], [`~EasyPipelineForImage2Image.from_huggingface`] or [`~EasyPipelineForImage2Image.from_civitai`] methods.
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
|
||||
@@ -1147,10 +1514,10 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForText2Image
|
||||
>>> from auto_diffusers import EasyPipelineForImage2Image
|
||||
|
||||
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> image = pipeline(prompt).images[0]
|
||||
>>> pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> image = pipeline(prompt, image).images[0]
|
||||
```
|
||||
"""
|
||||
# Update kwargs to ensure the model is downloaded and parameters are included
|
||||
@@ -1163,20 +1530,22 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
|
||||
kwargs.update(_parmas)
|
||||
|
||||
# Search for the model on Hugging Face and get the model status
|
||||
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
|
||||
checkpoint_path = model_status.model_path
|
||||
hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
|
||||
checkpoint_path = hf_checkpoint_status.model_path
|
||||
|
||||
# Check the format of the model checkpoint
|
||||
if model_status.checkpoint_format == "single_file":
|
||||
if hf_checkpoint_status.loading_method == "from_single_file":
|
||||
# Load the pipeline from a single file checkpoint
|
||||
return load_pipeline_from_single_file(
|
||||
pipeline = load_pipeline_from_single_file(
|
||||
pretrained_model_or_path=checkpoint_path,
|
||||
pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return cls.from_pretrained(checkpoint_path, **kwargs)
|
||||
pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
|
||||
|
||||
return add_methods(pipeline)
|
||||
|
||||
@classmethod
|
||||
def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
|
||||
@@ -1255,10 +1624,10 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForText2Image
|
||||
>>> from auto_diffusers import EasyPipelineForImage2Image
|
||||
|
||||
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> image = pipeline(prompt).images[0]
|
||||
>>> pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> image = pipeline(prompt, image).images[0]
|
||||
```
|
||||
"""
|
||||
# Update kwargs to ensure the model is downloaded and parameters are included
|
||||
@@ -1271,24 +1640,25 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
|
||||
kwargs.update(_status)
|
||||
|
||||
# Search for the model on Civitai and get the model status
|
||||
model_status = search_civitai(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
|
||||
checkpoint_path = model_status.model_path
|
||||
checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
|
||||
checkpoint_path = checkpoint_status.model_path
|
||||
|
||||
# Load the pipeline from a single file checkpoint
|
||||
return load_pipeline_from_single_file(
|
||||
pipeline = load_pipeline_from_single_file(
|
||||
pretrained_model_or_path=checkpoint_path,
|
||||
pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
|
||||
**kwargs,
|
||||
)
|
||||
return add_methods(pipeline)
|
||||
|
||||
|
||||
class EasyPipelineForInpainting(AutoPipelineForInpainting):
|
||||
r"""
|
||||
|
||||
[`AutoPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
|
||||
[`EasyPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
|
||||
specific underlying pipeline class is automatically selected from either the
|
||||
[`~AutoPipelineForInpainting.from_pretrained`] or [`~AutoPipelineForInpainting.from_pipe`] methods.
|
||||
[`~EasyPipelineForInpainting.from_pretrained`], [`~EasyPipelineForInpainting.from_pipe`], [`~EasyPipelineForInpainting.from_huggingface`] or [`~EasyPipelineForInpainting.from_civitai`] methods.
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throws an error).
|
||||
|
||||
@@ -1403,10 +1773,10 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForText2Image
|
||||
>>> from auto_diffusers import EasyPipelineForInpainting
|
||||
|
||||
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> image = pipeline(prompt).images[0]
|
||||
>>> pipeline = EasyPipelineForInpainting.from_huggingface("stable-diffusion-2-inpainting")
|
||||
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
```
|
||||
"""
|
||||
# Update kwargs to ensure the model is downloaded and parameters are included
|
||||
@@ -1419,20 +1789,21 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
|
||||
kwargs.update(_status)
|
||||
|
||||
# Search for the model on Hugging Face and get the model status
|
||||
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
|
||||
checkpoint_path = model_status.model_path
|
||||
hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
|
||||
checkpoint_path = hf_checkpoint_status.model_path
|
||||
|
||||
# Check the format of the model checkpoint
|
||||
if model_status.checkpoint_format == "single_file":
|
||||
if hf_checkpoint_status.loading_method == "from_single_file":
|
||||
# Load the pipeline from a single file checkpoint
|
||||
return load_pipeline_from_single_file(
|
||||
pipeline = load_pipeline_from_single_file(
|
||||
pretrained_model_or_path=checkpoint_path,
|
||||
pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return cls.from_pretrained(checkpoint_path, **kwargs)
|
||||
pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
|
||||
return add_methods(pipeline)
|
||||
|
||||
@classmethod
|
||||
def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
|
||||
@@ -1511,10 +1882,10 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForText2Image
|
||||
>>> from auto_diffusers import EasyPipelineForInpainting
|
||||
|
||||
>>> pipeline = AutoPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
|
||||
>>> image = pipeline(prompt).images[0]
|
||||
>>> pipeline = EasyPipelineForInpainting.from_huggingface("stable-diffusion-2-inpainting")
|
||||
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
```
|
||||
"""
|
||||
# Update kwargs to ensure the model is downloaded and parameters are included
|
||||
@@ -1527,13 +1898,14 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
|
||||
kwargs.update(_status)
|
||||
|
||||
# Search for the model on Civitai and get the model status
|
||||
model_status = search_civitai(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
|
||||
checkpoint_path = model_status.model_path
|
||||
checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
|
||||
logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
|
||||
checkpoint_path = checkpoint_status.model_path
|
||||
|
||||
# Load the pipeline from a single file checkpoint
|
||||
return load_pipeline_from_single_file(
|
||||
pipeline = load_pipeline_from_single_file(
|
||||
pretrained_model_or_path=checkpoint_path,
|
||||
pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
|
||||
**kwargs,
|
||||
)
|
||||
return add_methods(pipeline)
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
# Diffusion Model Alignment Using GRPO
|
||||
|
||||
|
||||
This directory provides LoRA implementations of Diffusion [GRPO](https://arxiv.org/abs/2402.03300) an RL based alignment method which is a variant of Proximal Policy Optimization (PPO) in the diffusion model setting.
|
||||
|
||||
## SDXL training command
|
||||
|
||||
```bash
|
||||
accelerate launch train_diffusion_grpo_sdxl.py \
|
||||
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--output_dir="diffusion-sdxl-dpo" \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=kashif/pickascore \
|
||||
--train_batch_size=8 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--rank=8 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=2000 \
|
||||
--checkpointing_steps=500 \
|
||||
--run_validation --validation_steps=50 \
|
||||
--seed="0" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub
|
||||
```
|
||||
@@ -0,0 +1,8 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft
|
||||
wandb
|
||||
File diff suppressed because it is too large
Load Diff
@@ -365,8 +365,8 @@ def parse_args():
|
||||
"--dream_training",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Use the DREAM training method, which makes training more efficient and accurate at the ",
|
||||
"expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210",
|
||||
"Use the DREAM training method, which makes training more efficient and accurate at the "
|
||||
"expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -519,7 +519,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
remaining_keys = list(sds_sd.keys())
|
||||
te_state_dict = {}
|
||||
if remaining_keys:
|
||||
if not all(k.startswith("lora_te1") for k in remaining_keys):
|
||||
if not all(k.startswith("lora_te") for k in remaining_keys):
|
||||
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
|
||||
for key in remaining_keys:
|
||||
if not key.endswith("lora_down.weight"):
|
||||
@@ -558,6 +558,88 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
new_state_dict = {**ait_sd, **te_state_dict}
|
||||
return new_state_dict
|
||||
|
||||
def _convert_mixture_state_dict_to_diffusers(state_dict):
|
||||
new_state_dict = {}
|
||||
|
||||
def _convert(original_key, diffusers_key, state_dict, new_state_dict):
|
||||
down_key = f"{original_key}.lora_down.weight"
|
||||
down_weight = state_dict.pop(down_key)
|
||||
lora_rank = down_weight.shape[0]
|
||||
|
||||
up_weight_key = f"{original_key}.lora_up.weight"
|
||||
up_weight = state_dict.pop(up_weight_key)
|
||||
|
||||
alpha_key = f"{original_key}.alpha"
|
||||
alpha = state_dict.pop(alpha_key)
|
||||
|
||||
# scale weight by alpha and dim
|
||||
scale = alpha / lora_rank
|
||||
# calculate scale_down and scale_up
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
down_weight = down_weight * scale_down
|
||||
up_weight = up_weight * scale_up
|
||||
|
||||
diffusers_down_key = f"{diffusers_key}.lora_A.weight"
|
||||
new_state_dict[diffusers_down_key] = down_weight
|
||||
new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
|
||||
|
||||
all_unique_keys = {
|
||||
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
|
||||
}
|
||||
all_unique_keys = sorted(all_unique_keys)
|
||||
assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
|
||||
|
||||
for k in all_unique_keys:
|
||||
if k.startswith("lora_transformer_single_transformer_blocks_"):
|
||||
i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
|
||||
diffusers_key = f"single_transformer_blocks.{i}"
|
||||
elif k.startswith("lora_transformer_transformer_blocks_"):
|
||||
i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
|
||||
diffusers_key = f"transformer_blocks.{i}"
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if "attn_" in k:
|
||||
if "_to_out_0" in k:
|
||||
diffusers_key += ".attn.to_out.0"
|
||||
elif "_to_add_out" in k:
|
||||
diffusers_key += ".attn.to_add_out"
|
||||
elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
|
||||
remaining = k.split("attn_")[-1]
|
||||
diffusers_key += f".attn.{remaining}"
|
||||
elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
|
||||
remaining = k.split("attn_")[-1]
|
||||
diffusers_key += f".attn.{remaining}"
|
||||
|
||||
if diffusers_key == f"transformer_blocks.{i}":
|
||||
print(k, diffusers_key)
|
||||
_convert(k, diffusers_key, state_dict, new_state_dict)
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(
|
||||
f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
|
||||
)
|
||||
|
||||
new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
|
||||
return new_state_dict
|
||||
|
||||
# This is weird.
|
||||
# https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
|
||||
# has both `peft` and non-peft state dict.
|
||||
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
|
||||
if has_peft_state_dict:
|
||||
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
|
||||
return state_dict
|
||||
# Another weird one.
|
||||
has_mixture = any(
|
||||
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
|
||||
)
|
||||
if has_mixture:
|
||||
return _convert_mixture_state_dict_to_diffusers(state_dict)
|
||||
return _convert_sd_scripts_to_ai_toolkit(state_dict)
|
||||
|
||||
|
||||
|
||||
@@ -177,5 +177,3 @@ class FluxTransformer2DLoadersMixin:
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
@@ -405,11 +405,12 @@ class Attention(nn.Module):
|
||||
else:
|
||||
try:
|
||||
# Make sure we can run the memory efficient attention
|
||||
_ = xformers.ops.memory_efficient_attention(
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
)
|
||||
dtype = None
|
||||
if attention_op is not None:
|
||||
op_fw, op_bw = attention_op
|
||||
dtype, *_ = op_fw.SUPPORTED_DTYPES
|
||||
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
|
||||
_ = xformers.ops.memory_efficient_attention(q, q, q)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ import torch.utils.checkpoint
|
||||
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..hooks import apply_layerwise_casting
|
||||
@@ -605,7 +606,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
||||
|
||||
|
||||
@@ -160,8 +160,10 @@ class AuraFlowPipeline(DiffusionPipeline):
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -348,7 +348,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
prompt_template=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
|
||||
@@ -630,6 +630,7 @@ def load_sub_model(
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
use_safetensors: bool,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]],
|
||||
provider_options: Any,
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
|
||||
@@ -676,6 +677,7 @@ def load_sub_model(
|
||||
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
|
||||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
loading_kwargs["provider_options"] = provider_options
|
||||
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
|
||||
from packaging import version
|
||||
from requests.exceptions import HTTPError
|
||||
from tqdm.auto import tqdm
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..configuration_utils import ConfigMixin
|
||||
@@ -513,7 +514,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights.
|
||||
|
||||
@@ -676,6 +677,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
custom_revision = kwargs.pop("custom_revision", None)
|
||||
provider = kwargs.pop("provider", None)
|
||||
sess_options = kwargs.pop("sess_options", None)
|
||||
provider_options = kwargs.pop("provider_options", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
offload_folder = kwargs.pop("offload_folder", None)
|
||||
@@ -970,6 +972,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
cached_folder=cached_folder,
|
||||
use_safetensors=use_safetensors,
|
||||
dduf_entries=dduf_entries,
|
||||
provider_options=provider_options,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
@@ -188,7 +188,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
|
||||
https://arxiv.org/abs/2302.08453
|
||||
|
||||
@@ -153,8 +153,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
|
||||
return model
|
||||
|
||||
|
||||
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
|
||||
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
|
||||
# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
|
||||
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
|
||||
"""
|
||||
Helper function to dequantize 4bit or 8bit bnb weights.
|
||||
|
||||
@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
|
||||
if state.SCB is None:
|
||||
state.SCB = weight.SCB
|
||||
|
||||
im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
|
||||
im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
|
||||
im, Sim = bnb.functional.transform(im, "col32")
|
||||
if state.CxB is None:
|
||||
state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
|
||||
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
|
||||
return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
|
||||
if hasattr(bnb.functional, "int8_vectorwise_dequant"):
|
||||
# Use bitsandbytes API if available (requires v0.45.0+)
|
||||
dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
|
||||
else:
|
||||
# Multiply by (scale/127) to dequantize.
|
||||
dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
|
||||
|
||||
if dtype:
|
||||
dequantized = dequantized.to(dtype)
|
||||
return dequantized
|
||||
|
||||
|
||||
def _create_accelerate_new_hook(old_hook):
|
||||
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):
|
||||
|
||||
def _dequantize_and_replace(
|
||||
model,
|
||||
dtype,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
|
||||
else:
|
||||
state = None
|
||||
|
||||
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
|
||||
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))
|
||||
|
||||
if bias is not None:
|
||||
new_module.bias = bias
|
||||
@@ -263,9 +267,10 @@ def _dequantize_and_replace(
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _dequantize_and_replace(
|
||||
module,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
dtype=dtype,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
current_key_name=current_key_name,
|
||||
quantization_config=quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
# Remove the last key for recursion
|
||||
@@ -280,6 +285,7 @@ def dequantize_and_replace(
|
||||
):
|
||||
model, has_been_replaced = _dequantize_and_replace(
|
||||
model,
|
||||
dtype=model.dtype,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper).
|
||||
rho (`float`, *optional*, defaults to 7.0):
|
||||
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
|
||||
final_sigmas_type (`str`, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
||||
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
@@ -92,6 +95,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
prediction_type: str = "epsilon",
|
||||
rho: float = 7.0,
|
||||
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
|
||||
):
|
||||
if sigma_schedule not in ["karras", "exponential"]:
|
||||
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
|
||||
@@ -99,15 +103,24 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
|
||||
ramp = torch.linspace(0, 1, num_train_timesteps)
|
||||
sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
|
||||
if sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
sigmas = self._compute_karras_sigmas(sigmas)
|
||||
elif sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
sigmas = self._compute_exponential_sigmas(sigmas)
|
||||
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = sigmas[-1]
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
sigma_last = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
||||
)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
|
||||
|
||||
self.is_scale_input_called = False
|
||||
|
||||
@@ -197,7 +210,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -206,19 +224,36 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
|
||||
Custom sigmas to use for the denoising process. If not defined, the default behavior when
|
||||
`num_inference_steps` is passed will be used.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
ramp = torch.linspace(0, 1, self.num_inference_steps)
|
||||
if sigmas is None:
|
||||
sigmas = torch.linspace(0, 1, self.num_inference_steps)
|
||||
elif isinstance(sigmas, float):
|
||||
sigmas = torch.tensor(sigmas, dtype=torch.float32)
|
||||
else:
|
||||
sigmas = sigmas
|
||||
if self.config.sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
sigmas = self._compute_karras_sigmas(sigmas)
|
||||
elif self.config.sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
sigmas = self._compute_exponential_sigmas(sigmas)
|
||||
|
||||
sigmas = sigmas.to(dtype=torch.float32, device=device)
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = sigmas[-1]
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
sigma_last = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
||||
)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@@ -248,7 +248,13 @@ def _set_state_dict_into_text_encoder(
|
||||
|
||||
|
||||
def compute_density_for_timestep_sampling(
|
||||
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
||||
weighting_scheme: str,
|
||||
batch_size: int,
|
||||
logit_mean: float = None,
|
||||
logit_std: float = None,
|
||||
mode_scale: float = None,
|
||||
device: Union[torch.device, str] = "cpu",
|
||||
generator: Optional[torch.Generator] = None,
|
||||
):
|
||||
"""
|
||||
Compute the density for sampling the timesteps when doing SD3 training.
|
||||
@@ -258,14 +264,13 @@ def compute_density_for_timestep_sampling(
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
||||
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
|
||||
u = torch.nn.functional.sigmoid(u)
|
||||
elif weighting_scheme == "mode":
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
u = torch.rand(size=(batch_size,), device=device, generator=generator)
|
||||
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
else:
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
u = torch.rand(size=(batch_size,), device=device, generator=generator)
|
||||
return u
|
||||
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = LlamaModel(llama_text_encoder_config)
|
||||
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
|
||||
tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
|
||||
@@ -155,10 +155,8 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
# Cannot test with dummy prompt because tokenizers are not configured correctly.
|
||||
# TODO(aryan): create dummy tokenizers and using from hub
|
||||
inputs = {
|
||||
"prompt": "",
|
||||
"prompt": "dance monkey",
|
||||
"prompt_template": {
|
||||
"template": "{}",
|
||||
"crop_start": 0,
|
||||
|
||||
Reference in New Issue
Block a user