Compare commits

...

163 Commits

Author SHA1 Message Date
sayakpaul 601c506918 up 2023-11-06 12:56:06 +05:30
sayakpaul f8a5e172cf up 2023-11-06 12:49:57 +05:30
sayakpaul 47e9219450 pop 2023-11-06 12:34:47 +05:30
sayakpaul 04d83c209d byte tensor 2023-11-06 12:20:38 +05:30
sayakpaul 0d81b2dab4 debug 2023-11-06 12:18:22 +05:30
sayakpaul e67ddf8d13 debug 2023-11-06 12:17:09 +05:30
sayakpaul cdbbc7d5b7 fix 2023-11-06 12:10:02 +05:30
sayakpaul 81fb265a08 also serialize the states. 2023-11-06 12:09:09 +05:30
sayakpaul 21c8c433a3 fix: scheduler assignment. 2023-11-06 11:58:35 +05:30
sayakpaul 3479e5311d workflow_copy 2023-11-06 11:57:22 +05:30
sayakpaul f8cad5dc4a debug 2023-11-06 11:50:24 +05:30
sayakpaul d75f8a537c debug 2023-11-06 11:36:01 +05:30
sayakpaul 4b94889652 generator 2023-11-06 11:33:00 +05:30
sayakpaul 35a6538343 debug generator 2023-11-06 11:30:25 +05:30
sayakpaul 1b65ff770c disable custom pop 2023-11-06 11:25:30 +05:30
sayakpaul 857c65bb56 make Workflow more lightweight. 2023-11-06 11:19:28 +05:30
sayakpaul b5752ec4bf resolve conflicts 2023-11-06 11:04:34 +05:30
sayakpaul 157405436b update doc. 2023-11-02 22:26:04 +05:30
sayakpaul d69f3079a8 Merge branch 'main' into feat/workflows 2023-11-02 22:20:45 +05:30
sayakpaul 407e669fca better error message 2023-11-02 22:16:42 +05:30
sayakpaul 5e00fcc153 generator 2023-11-02 13:20:17 +05:30
sayakpaul 1dd8cf5abe generator 2023-11-02 13:03:32 +05:30
sayakpaul 10739166e2 fix: tests. 2023-11-02 11:55:51 +05:30
sayakpaul c1d8b882ee use return_worflow at the end. 2023-11-02 10:59:51 +05:30
sayakpaul 6697144a7d fix: toc 2023-11-02 10:58:29 +05:30
sayakpaul 6af69f5639 Merge branch 'main' into feat/workflows 2023-11-02 10:40:28 +05:30
sayakpaul c28ea5e6c6 separate api page 2023-11-02 10:39:12 +05:30
sayakpaul cb8902394d clean docs. 2023-11-02 10:35:41 +05:30
Sayak Paul 5c8d5df564 Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-11-02 10:31:30 +05:30
sayakpaul 474df7a6b1 clean 2023-11-02 10:27:28 +05:30
sayakpaul 047cede64a clean up tests 2023-10-29 17:28:03 +05:30
sayakpaul ce9e17c547 handle list of generators. 2023-10-29 17:20:48 +05:30
sayakpaul abdb4ffbc6 more fixes. 2023-10-29 17:11:39 +05:30
sayakpaul 87174fc208 load from a workflow directly 2023-10-29 17:09:36 +05:30
sayakpaul 425e75bc79 remove prompt_embeds. 2023-10-29 17:06:28 +05:30
sayakpaul f14cf8f4aa boom 2023-10-29 13:47:05 +05:30
sayakpaul 2f7725edf4 style 2023-10-29 11:21:40 +05:30
sayakpaul a11ce647de fix 2023-10-29 11:21:19 +05:30
sayakpaul 6f65f3ad3e fix 2023-10-29 11:07:02 +05:30
sayakpaul cbc835b870 correct output. 2023-10-29 10:59:31 +05:30
sayakpaul 8550a86a17 done 2023-10-29 10:57:52 +05:30
sayakpaul 477cc9a82a remove print 2023-10-29 10:56:23 +05:30
sayakpaul 28c8e93179 fix 2023-10-29 10:48:24 +05:30
sayakpaul 3f90e07228 fix: 2023-10-29 10:46:46 +05:30
sayakpaul 67f7757048 fix 2023-10-29 10:39:40 +05:30
sayakpaul 0e40d6ffd6 fix 2023-10-29 10:37:52 +05:30
sayakpaul 3eab48f883 stringify device. 2023-10-29 10:31:58 +05:30
sayakpaul b3a675288d corrections to doc. 2023-10-29 10:31:13 +05:30
sayakpaul 9099f51c5e fix? 2023-10-29 10:21:34 +05:30
sayakpaul 5dcd8c541e debug 2023-10-29 10:20:20 +05:30
sayakpaul ac73c86610 debug 2023-10-29 10:19:58 +05:30
sayakpaul 9836b61fa8 debug 2023-10-29 10:19:09 +05:30
sayakpaul 1829b9485c debug 2023-10-29 10:18:07 +05:30
sayakpaul 6e514c2b6a debug 2023-10-29 10:13:37 +05:30
sayakpaul b5888b4704 typo 2023-10-29 10:11:43 +05:30
sayakpaul 003220ba36 typo 2023-10-29 10:10:14 +05:30
sayakpaul f221631d3c typo 2023-10-29 10:08:55 +05:30
sayakpaul adcaba0a23 typo 2023-10-29 10:08:14 +05:30
sayakpaul bd1f78e6cd typo 2023-10-29 10:05:19 +05:30
sayakpaul ecfa79b673 fix: generator 2023-10-29 10:04:48 +05:30
sayakpaul ab1d58872b debug generator 2023-10-29 09:56:57 +05:30
sayakpaul 020b4a4ad7 fix: generator update 2023-10-29 09:38:47 +05:30
sayakpaul e510c3d4d5 fix import 2023-10-29 09:35:41 +05:30
sayakpaul f0418e8896 serialize scheduler info too. 2023-10-29 09:27:27 +05:30
sayakpaul c4eebd9c1a randomize seed. 2023-10-29 09:05:08 +05:30
sayakpaul 84de851116 change the path 2023-10-28 10:11:55 +05:30
sayakpaul eae28adf40 some more notes 2023-10-28 10:08:54 +05:30
Sayak Paul 450198061e Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-10-28 10:04:17 +05:30
sayakpaul be3ff851c6 fix: test 2023-10-27 18:48:07 +05:30
sayakpaul c3792ba3e0 fix: 2023-10-27 15:52:44 +05:30
sayakpaul 7933f2ac18 fix: 2023-10-27 15:44:55 +05:30
sayakpaul 1b7a7c27d3 fix: 2023-10-27 15:39:31 +05:30
sayakpaul 4689d759fd fix test 2023-10-27 15:37:42 +05:30
sayakpaul d06cc7eb6f fix: test 2023-10-27 15:30:54 +05:30
sayakpaul b5780adf46 fix: test 2023-10-27 15:29:09 +05:30
sayakpaul d07d3f1642 fix: tokenizer 2023-10-27 15:28:30 +05:30
sayakpaul e5a69ff497 fix: import 2023-10-27 15:25:14 +05:30
sayakpaul 43846e14a1 fix: import 2023-10-27 15:22:57 +05:30
sayakpaul 7fa7259ded fix: doc 2023-10-27 15:17:38 +05:30
sayakpaul b576a1dc47 fix: doc 2023-10-27 15:06:27 +05:30
sayakpaul 9fc37d9dc7 add: hub related staging tests 2023-10-27 14:54:45 +05:30
sayakpaul 0a298f55fc fix: _name_or_path 2023-10-27 14:44:14 +05:30
sayakpaul 7382344fed add: test 2023-10-27 14:33:08 +05:30
sayakpaul 93ce75f4e5 add: entry to toctree 2023-10-27 14:25:29 +05:30
sayakpaul b42bcb86ea add: doc 2023-10-27 14:22:48 +05:30
sayakpaul 47fe2d0d2e fix pop call. 2023-10-27 13:42:41 +05:30
sayakpaul 6a59219e81 more rigorous 2023-10-27 13:41:14 +05:30
sayakpaul f8eff79b82 fxi 2023-10-27 13:31:13 +05:30
sayakpaul d01f2f678a rigid call arguments. 2023-10-27 13:30:10 +05:30
sayakpaul ae0a268f8e add to controlnet 2023-10-27 13:24:16 +05:30
sayakpaul fce889f19b make push_to_hub False by default 2023-10-27 13:00:27 +05:30
sayakpaul 01b3f64549 fix: save 2023-10-27 12:58:51 +05:30
sayakpaul 53c65e3d19 add: support for serializing generator device. 2023-10-27 12:12:34 +05:30
sayakpaul 6d48af1a46 add back the comment 2023-10-27 12:08:05 +05:30
sayakpaul 66d3fd6732 remove filename stuff from config utils. 2023-10-27 12:06:58 +05:30
sayakpaul 0cc943b757 Merge branch 'main' into feat/workflows 2023-10-27 12:00:28 +05:30
Sayak Paul 41a74f8474 Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-10-27 12:00:01 +05:30
sayakpaul 9a8b5f7cd8 address patrick's comments./ 2023-10-19 10:59:58 +05:30
sayakpaul ebf2addb86 Merge branch 'main' into feat/workflows 2023-10-19 10:56:12 +05:30
sayakpaul 03bfdff59a Empty-Commit 2023-10-18 14:40:48 +05:30
sayakpaul ff5cd58aa1 support basic lora only for non-peft for now. 2023-10-18 12:39:41 +05:30
sayakpaul c1c11a6747 fix: lora 2023-10-18 12:25:15 +05:30
sayakpaul 49e06fdd2c fix: lora population. 2023-10-18 12:23:41 +05:30
sayakpaul f874578a4d fix: lora population. 2023-10-18 12:23:03 +05:30
Sayak Paul c0e1c6348f Merge branch 'main' into feat/workflows 2023-10-18 12:15:17 +05:30
sayakpaul 231e8314dd feat: support passing filename 2023-10-18 11:50:43 +05:30
sayakpaul 21e5bb65e5 replace the __call__ attribute of the class, not the instance 2023-10-18 11:40:09 +05:30
sayakpaul 74e766c0b8 quality 2023-10-18 11:20:36 +05:30
sayakpaul 452bf4fa05 hmm almost 2023-10-18 11:16:44 +05:30
sayakpaul 55c47bc751 let's see 2023-10-18 10:56:25 +05:30
sayakpaul eaae2df25c debugging. 2023-10-17 18:52:35 +05:30
sayakpaul d612b5435b debug 2023-10-17 18:49:03 +05:30
sayakpaul 1dc9854968 copying helps? 2023-10-17 18:44:59 +05:30
sayakpaul 18a756f0ad style. 2023-10-17 16:57:30 +05:30
Sayak Paul 4993c8ba63 Merge branch 'main' into feat/workflows 2023-10-17 16:54:09 +05:30
sayakpaul ed9acd6426 remove print 2023-10-17 16:36:05 +05:30
sayakpaul a69e3d15c1 debug. 2023-10-17 16:34:09 +05:30
sayakpaul 4731f65ed4 debug. 2023-10-17 16:29:14 +05:30
sayakpaul 9adaa1739a debug 2023-10-17 16:16:10 +05:30
sayakpaul af282b7f4b debug 2023-10-17 16:13:32 +05:30
sayakpaul 3d6637b65d partial 2023-10-17 16:10:43 +05:30
sayakpaul 800b7a0fda morr 2023-10-17 16:04:53 +05:30
sayakpaul 91c1c1f1f6 apply styling 2023-10-17 16:03:44 +05:30
sayakpaul 21d19bbc44 workflow_filename -> filename 2023-10-17 16:03:25 +05:30
sayakpaul 9d0bcd48f1 debug 2023-10-17 15:50:45 +05:30
sayakpaul 9ee8b0a070 debug 2023-10-17 15:49:34 +05:30
sayakpaul b5fd337875 debug 2023-10-17 15:47:41 +05:30
sayakpaul f08f40bde1 debug 2023-10-17 15:46:18 +05:30
sayakpaul c5ff8cd943 debug 2023-10-17 15:45:05 +05:30
sayakpaul 319456049a seed. 2023-10-17 15:28:03 +05:30
sayakpaul f6c0878fc6 callables should not be serialized too. 2023-10-17 15:24:26 +05:30
sayakpaul e590b73cc1 override pop too for feature compatibility 2023-10-17 15:22:59 +05:30
sayakpaul aa7839c1c7 pop from internal dict too. 2023-10-17 15:22:22 +05:30
sayakpaul fc609e308f more fix 2023-10-17 15:19:55 +05:30
sayakpaul 7b85bfe3e5 fix: signature 2023-10-17 15:18:31 +05:30
sayakpaul eff03fd054 override method 2023-10-17 15:16:21 +05:30
sayakpaul 73dcc17ff1 remove unneeded comment 2023-10-17 15:05:54 +05:30
sayakpaul b149800269 remove unneeded comment 2023-10-17 15:04:01 +05:30
sayakpaul 2d1cd20afe stronger check 2023-10-17 15:03:32 +05:30
sayakpaul 45c5656bad make config_name a part of the dict. 2023-10-17 14:41:25 +05:30
sayakpaul 2b48d8572d save_pretrained() to workflow so that it has push_to_hub 2023-10-17 14:31:15 +05:30
sayakpaul e710121a9b save_pretrained() to workflow so that it has push_to_hub 2023-10-17 14:28:42 +05:30
sayakpaul 50769e058b remove torch tensor warning as it might complicate things 2023-10-17 14:19:34 +05:30
sayakpaul 0bd97735dc debug 2023-10-17 14:16:43 +05:30
sayakpaul 930ca765f4 debug 2023-10-17 14:13:34 +05:30
sayakpaul 807c2ca13f debug 2023-10-17 14:04:22 +05:30
sayakpaul ad725977cc patch call. 2023-10-17 13:52:54 +05:30
sayakpaul 97ae043f8a update docstrings. 2023-10-17 13:46:49 +05:30
sayakpaul 1ab81a6db4 update progress. 2023-10-17 13:05:14 +05:30
sayakpaul a6a0277713 include pipeline name in the workflow 2023-10-17 12:28:45 +05:30
sayakpaul ba0b1e857c improve docstring 2023-10-17 12:24:35 +05:30
sayakpaul d8e6f38db4 change method desc. 2023-10-17 12:22:40 +05:30
sayakpaul ef94a008d2 handle torch.tensor. 2023-10-17 12:22:02 +05:30
sayakpaul 29d0aa887c remove components from workflows. 2023-10-17 12:14:43 +05:30
sayakpaul 5f19b66d5a resolve conflicts 2023-10-15 12:04:25 +05:30
sayakpaul 96c55d4c5a fix 2023-10-15 11:58:20 +05:30
sayakpaul e3611e325b properly set lora_info 2023-10-15 11:40:00 +05:30
sayakpaul d5d31e0ae3 add: support for lora. 2023-10-15 11:09:03 +05:30
sayakpaul a62b77ff6e include todos. 2023-10-15 10:52:11 +05:30
sayakpaul a8a1378987 fix 2023-10-15 10:48:40 +05:30
Sayak Paul e8e09e48ea Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-08-30 14:14:26 +05:30
Sayak Paul ac295055ce add unifinished implementation of _update_call() 2023-08-29 14:45:09 +05:30
Sayak Paul 43e4e841f9 add: workflows. 2023-08-29 13:59:22 +05:30
10 changed files with 835 additions and 4 deletions
+4
View File
@@ -19,6 +19,8 @@
title: Train a diffusion model
- local: tutorials/using_peft_for_inference
title: Inference with PEFT
- local: tutorials/workflows
title: Working with workflows
title: Tutorials
- sections:
- sections:
@@ -178,6 +180,8 @@
title: Logging
- local: api/outputs
title: Outputs
- local: api/workflows
title: Shareable workflows
title: Main Classes
- sections:
- local: api/models/overview
+7
View File
@@ -0,0 +1,7 @@
# Shareable workflows
Workflows provide a simple mechanism to share your 🤗 Diffusers pipeline call arguments and scheduler configuration, making it easier to reproduce results.
## Workflow
[[autodoc]] workflow_utils.Workflow
+333
View File
@@ -0,0 +1,333 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Working with workflows
<Tip warning={true}>
🧪 Workflow is experimental and its APIs can change in the future.
</Tip>
Workflows provide a simple mechanism to share your pipeline call arguments and scheduler configuration, making it easier to reproduce results.
## Serializing a workflow
A [`Workflow`] object provides all the argument values in the `__call__()` of a pipeline. Add `return_workflow=True` to return a `Workflow` object.
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
outputs = pipeline(
"A painting of a horse",
num_inference_steps=15,
generator=torch.manual_seed(0),
return_workflow=True
)
workflow = outputs.workflow
```
<Tip warning={true}>
It's mandatory to specify the `generator` when `return_workflow` is set to True.
</Tip>
If you look at this specific workflow, you'll see values like the number of inference steps, guidance scale, and height and width as well as the scheduler details:
```bash
{'prompt': 'A painting of a horse',
'height': None,
'width': None,
'num_inference_steps': 15,
'guidance_scale': 7.5,
'negative_prompt': None,
'eta': 0.0,
'latents': None,
'prompt_embeds': None,
'negative_prompt_embeds': None,
'output_type': 'pil',
'return_dict': True,
'callback': None,
'callback_steps': 1,
'cross_attention_kwargs': None,
'guidance_rescale': 0.0,
'clip_skip': None,
'generator_seed': 0,
'generator_device': device(type='cpu'),
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('prediction_type', 'epsilon'),
('timestep_spacing', 'leading'),
('steps_offset', 1),
('_use_default_values', ['prediction_type', 'timestep_spacing']),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.6.0'),
('clip_sample', False)])}
```
Once you have generated a workflow object, you can serialize it with [`~Workflow.save_workflow`]:
```python
outputs.workflow.save_workflow("my-simple-workflow-sd")
```
By default, your workflows are saved as `diffusion_workflow.json`, but you can give them a specific name with the `filename` argument:
```python
outputs.workflow.save_workflow("my-simple-workflow-sd", filename="my_workflow.json")
```
You can also set `push_to_hub=True` in [`~Workflow.save_workflow`] to directly push the workflow object to the Hub.
## Loading a workflow
You can load a workflow in a pipeline with [`~DiffusionPipeline.load_workflow`]:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
```
Once the pipeline is loaded with the desired workflow, it's ready to be called:
```python
image = pipeline().images[0]
```
By default, while loading a workflow, the scheduler of the underlying pipeline from the workflow isn't modified but you can change it by adding `load_scheduler=True`:
```
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd", load_scheduler=True)
```
This is particularly useful if you have changed the scheduler after loading a pipeline.
You can also override the pipeline call arguments. For example, to add a `negative_prompt`:
```python
image = pipeline(negative_prompt="bad quality").images[0]
```
Loading from a workflow is possible by specifying the `filename` argument inside the [`DiffusionPipeline.load_workflow`] method.
A workflow doesn't necessarily have to be used with the same pipeline that generated it. You can use it with a different pipeline too:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_workflow("sayakpaul/my-simple-workflow-sd")
image = pipeline().images[0]
```
However, make sure to thoroughly inspect the values you are calling the pipeline with, in this case.
Loading from a local workflow is also possible:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_workflow("path_to_local_dir")
image = pipeline().images[0]
```
Alternatively, if you want to load a workflow file and populate the pipeline arguments manually:
```python
from diffusers import DiffusionPipeline
import json
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
with open("path_to_workflow_file.json") as f:
workflow = json.load(f)
pipeline.load_workflow(workflow)
images = pipeline().images[0]
```
## Unsupported serialization types
Image-to-image pipelines like [`StableDiffusionControlNetPipeline`] accept one or more images in their `call` method. Currently, workflows don't support serializing `call` arguments that are of type `PIL.Image.Image` or `List[PIL.Image.Image]`. To make those pipelines work with workflows, you need to pass the images manually.
Let's say you generated the workflow below:
```python
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import numpy as np
import torch
import cv2
from PIL import Image
# download an image
image = load_image(
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image = np.array(image)
# get canny image
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
# load control net and stable diffusion v1-5
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
# generate image
generator = torch.manual_seed(0)
outputs = pipe(
prompt="futuristic-looking office",
image=canny_image,
num_inference_steps=20,
generator=generator,
return_workflow=True
)
workflow = outputs.workflow
```
If you look at the workflow, you'll see the image that was passed to the pipeline isn't included:
```bash
{'prompt': 'futuristic-looking office',
'height': None,
'width': None,
'num_inference_steps': 20,
'guidance_scale': 7.5,
'negative_prompt': None,
'eta': 0.0,
'latents': None,
'prompt_embeds': None,
'negative_prompt_embeds': None,
'output_type': 'pil',
'return_dict': True,
'callback': None,
'callback_steps': 1,
'cross_attention_kwargs': None,
'controlnet_conditioning_scale': 1.0,
'guess_mode': False,
'control_guidance_start': 0.0,
'control_guidance_end': 1.0,
'clip_skip': None,
'generator_seed': 0,
'generator_device': 'cpu',
'_name_or_path': 'runwayml/stable-diffusion-v1-5',
'scheduler_config': FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('solver_order', 2),
('prediction_type', 'epsilon'),
('thresholding', False),
('dynamic_thresholding_ratio', 0.995),
('sample_max_value', 1.0),
('predict_x0', True),
('solver_type', 'bh2'),
('lower_order_final', True),
('disable_corrector', []),
('solver_p', None),
('use_karras_sigmas', False),
('timestep_spacing', 'linspace'),
('steps_offset', 1),
('_use_default_values',
['lower_order_final',
'sample_max_value',
'solver_p',
'dynamic_thresholding_ratio',
'thresholding',
'solver_type',
'prediction_type',
'predict_x0',
'use_karras_sigmas',
'disable_corrector',
'timestep_spacing',
'solver_order']),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.6.0'),
('clip_sample', False)])}
```
Let's serialize the workflow and reload the pipeline to see what happens when you try to use it.
```python
workflow.save_workflow("my-simple-workflow-sd", filename="controlnet_simple.json", push_to_hub=True)
```
Then load the workflow into [`StableDiffusionControlNetPipeline`]:
```python
# load control net and stable diffusion v1-5
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.load_workflow("sayakpaul/my-simple-workflow-sd", filename="controlnet_simple.json")
```
If you try to generate an image now, it'll return the following error:
```bash
TypeError: image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is <class 'NoneType'>
```
To resolve the error, manually pass the conditioning image `canny_image`:
```python
image = pipe(image=canny_image).images[0]
```
Other unsupported serialization types include:
* LoRA checkpoints: any information from LoRA checkpoints that might be loaded into a pipeline isn't serialized. Workflows generated from pipelines loaded with a LoRA checkpoint should be handled cautiously! You should ensure the LoRA checkpoint is loaded into the pipeline first before loading the corresponding workflow.
* Call arguments including the following types: `torch.Tensor`, `np.ndarray`, `Callable`, `PIL.Image.Image`, and `List[PIL.Image.Image]`.
@@ -752,6 +752,7 @@ class StableDiffusionControlNetPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
return_workflow: bool = False,
clip_skip: Optional[int] = None,
):
r"""
@@ -824,6 +825,8 @@ class StableDiffusionControlNetPipeline(
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
return_workflow (`bool`, *optional*, defaults to `False`):
Whether to return used pipeline call arguments.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -837,6 +840,14 @@ class StableDiffusionControlNetPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# We do this first to capture the "True" call values. If we do this at a later point in time,
# we cannot ensure that the call values weren't changed during the process.
workflow = None
if return_workflow:
if generator is None:
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
workflow = self.populate_workflow_from_pipeline()
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1075,6 +1086,11 @@ class StableDiffusionControlNetPipeline(
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
outputs = (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
if return_workflow:
outputs += (workflow,)
return outputs
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)
+118
View File
@@ -22,6 +22,7 @@ import re
import sys
import warnings
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
@@ -54,7 +55,9 @@ from ..utils import (
logging,
numpy_to_pil,
)
from ..utils.constants import WORKFLOW_NAME
from ..utils.torch_utils import is_compiled_module
from ..workflow_utils import _NON_CALL_ARGUMENTS, Workflow
if is_transformers_available():
@@ -64,6 +67,7 @@ if is_transformers_available():
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
@@ -2075,3 +2079,117 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
for module in modules:
module.set_attention_slice(slice_size)
def populate_workflow_from_pipeline(self) -> Dict:
r"""Populates the call arguments in a dictionary.
Returns:
[`Workflow`]: A dictionary containing the details of the pipeline call arguments and (optionally) LoRA
checkpoint details.
"""
# A `Workflow` object is an extended Python dictionary. So, all regular dictionary methods
# apply to it.
workflow = Workflow()
signature = inspect.signature(self.__call__)
argument_names = [param.name for param in signature.parameters.values()]
call_arg_values = inspect.getargvalues(inspect.currentframe().f_back).locals
# Populate call arguments.
call_arguments = {
arg: call_arg_values[arg]
for arg in argument_names
if arg != "return_workflow"
and "image" not in arg
and not isinstance(call_arg_values[arg], (torch.Tensor, np.ndarray, Callable))
}
workflow.update(call_arguments)
# Handle generator device and seed.
generator = workflow["generator"]
if isinstance(generator, list):
for g in generator:
if "generator_seed" not in workflow:
workflow.update({"generator_seed": [g.initial_seed()]})
workflow.update({"generator_device": [str(g.device)]})
workflow.update({"generator_state": g.get_state().numpy().tolist()})
else:
workflow["generator_seed"].append(g.initial_seed())
workflow["generator_device"].append(g.device)
workflow["generator_state"].append(g.get_state().numpy().tolist())
else:
workflow.update({"generator_seed": generator.initial_seed()})
workflow.update({"generator_device": str(generator.device)})
workflow.update({"generator_state": generator.get_state().numpy().tolist()})
workflow.pop("generator")
# Handle pipeline-level things.
if hasattr(self, "config") and hasattr(self.config, "_name_or_path"):
pipeline_config_name_or_path = self.config._name_or_path
else:
pipeline_config_name_or_path = None
workflow["_name_or_path"] = pipeline_config_name_or_path
workflow["scheduler_config"] = self.scheduler.config
return workflow
def load_workflow(
self,
workflow_id_or_path: Union[str, dict],
filename: Optional[str] = None,
):
r"""Loads a workflow from the Hub or from a local path. Also patches the pipeline call arguments with values from the
workflow.
Args:
workflow_id_or_path (`str` or `dict`):
Can be either:
- A string, the workflow id (for example `sayakpaul/sdxl-workflow`) of a workflow hosted on the
Hub.
- A path to a directory (for example `./my_workflow_directory`) containing the workflow file with
[`Workflow.save_workflow`] or [`Workflow.push_to_hub`].
- A Python dictionary.
filename (`str`, *optional*):
Optional name of the workflow file to load. Especially useful when working with multiple workflow
files.
"""
filename = filename or WORKFLOW_NAME
# Load workflow.
if not isinstance(workflow_id_or_path, dict):
if os.path.isdir(workflow_id_or_path):
workflow_filepath = os.path.join(workflow_id_or_path, filename)
elif os.path.isfile(workflow_id_or_path):
workflow_filepath = workflow_id_or_path
else:
workflow_filepath = hf_hub_download(repo_id=workflow_id_or_path, filename=filename)
workflow = self._dict_from_json_file(workflow_filepath)
else:
workflow = workflow_id_or_path
# We make a copy of the original workflow and operate on it.
workflow_copy = dict(workflow.items())
# Handle generator.
seed = workflow_copy.pop("generator_seed")
device = workflow_copy.pop("generator_device", "cpu")
last_known_state = workflow_copy.pop("generator_state")
if isinstance(seed, list):
generator = [
torch.Generator(device=d).manual_seed(s).set_state(torch.from_numpy(np.array(lst)).byte())
for s, d, lst in zip(seed, device, last_known_state)
]
else:
last_known_state = torch.from_numpy(np.array(last_known_state)).byte()
generator = torch.Generator(device=device).manual_seed(seed).set_state(last_known_state)
workflow_copy.update({"generator": generator})
# Handle non-call arguments.
final_call_args = {k: v for k, v in workflow_copy.items() if k not in _NON_CALL_ARGUMENTS}
# Handle the call here.
partial_call = partial(self.__call__, **final_call_args)
setattr(self.__class__, "__call__", partial_call)
@@ -19,10 +19,13 @@ class StableDiffusionPipelineOutput(BaseOutput):
nsfw_content_detected (`List[bool]`)
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
`None` if safety checking could not be performed.
workflow (`dict`):
Dictionary containing pipeline component configurations and call arguments
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
workflow: Optional[dict] = None
if is_flax_available():
@@ -623,6 +623,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
return_workflow: bool = False,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -677,6 +678,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
return_workflow (`bool`, *optional*, defaults to `False`):
Whether to return used pipeline call arguments.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -699,6 +702,13 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# We do this first to capture the "True" call values. If we do this at a later point in time,
# we cannot ensure that the call values weren't changed during the process.
workflow = None
if return_workflow:
if generator is None:
raise ValueError(f"`generator` cannot be None when `return_workflow` is {return_workflow}.")
workflow = self.populate_workflow_from_pipeline()
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
@@ -855,6 +865,11 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
outputs = (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
if return_workflow:
outputs += (workflow,)
return outputs
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, workflow=workflow)
+3
View File
@@ -14,6 +14,7 @@
import importlib
import os
import numpy as np
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
from packaging import version
@@ -32,11 +33,13 @@ FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
WORKFLOW_NAME = "diffusion_workflow.json"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
MAX_SEED = np.iinfo(np.int32).max
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
+161
View File
@@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for managing workflows."""
import json
import os
from pathlib import PosixPath
from typing import Union
import numpy as np
from huggingface_hub import create_repo
from . import __version__
from .utils import PushToHubMixin, logging
from .utils.constants import WORKFLOW_NAME
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_NON_CALL_ARGUMENTS = {"_name_or_path", "scheduler_config", "_class_name", "_diffusers_version"}
class Workflow(dict, PushToHubMixin):
"""Class sub-classing from native Python dict to have support for interacting with the Hub."""
config_name = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.config_name = WORKFLOW_NAME
self._internal_dict = {}
def __setitem__(self, __key, __value):
self._internal_dict[__key] = __value
return super().__setitem__(__key, __value)
def update(self, __m, **kwargs):
self._internal_dict.update(__m, **kwargs)
super().update(__m, **kwargs)
def pop(self, key, *args):
self._internal_dict.pop(key, *args)
return super().pop(key, *args)
# Copied from diffusers.configuration_utils.ConfigMixin.to_json_string
def to_json_string(self) -> str:
"""
Serializes the configuration instance to a JSON string.
Returns:
`str`:
String containing all the attributes that make up the configuration instance in JSON format.
"""
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
config_dict["_class_name"] = self.__class__.__name__
config_dict["_diffusers_version"] = __version__
def to_json_saveable(value):
if isinstance(value, np.ndarray):
value = value.tolist()
elif isinstance(value, PosixPath):
value = str(value)
return value
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
# Don't save "_ignore_files" or "_use_default_values"
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def save_workflow(
self,
save_directory: Union[str, os.PathLike],
push_to_hub: bool = False,
filename: str = WORKFLOW_NAME,
**kwargs,
):
"""
Saves a workflow to a directory.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the workflow JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
filename (`str`, *optional*, defaults to `workflow.json`):
Optional filename to use to serialize the workflow JSON.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
self.config_name = filename
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
output_config_file = os.path.join(save_directory, self.config_name)
with open(output_config_file, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
logger.info(f"Configuration saved in {output_config_file}")
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
push_to_hub: bool = False,
filename: str = WORKFLOW_NAME,
**kwargs,
):
"""
Saves a workflow to a directory. This internally calls [`Workflow.save_workflow`], This method exists to have
feature parity with [`PushToHubMixin.push_to_hub`].
Args:
save_directory (`str` or `os.PathLike`):
Directory where the workflow JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
filename (`str`, *optional*, defaults to `workflow.json`):
Optional filename to use to serialize the workflow JSON.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
self.save_workflow(
save_directory=save_directory,
push_to_hub=push_to_hub,
filename=filename,
**kwargs,
)
+171
View File
@@ -0,0 +1,171 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
import unittest
import uuid
import numpy as np
import torch
from huggingface_hub import delete_repo, hf_hub_download
from test_utils import TOKEN, USER, is_staging_test
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.utils.constants import WORKFLOW_NAME
from diffusers.utils.testing_utils import torch_device
from diffusers.workflow_utils import Workflow
class WorkflowFastTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=1,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=64,
layer_norm_eps=1e-05,
num_attention_heads=8,
num_hidden_layers=3,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "np",
}
return inputs
def test_workflow_with_stable_diffusion(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs, return_workflow=True)
image = output.images
image_slice = image[0, -3:, -3:, -1]
with tempfile.TemporaryDirectory() as tmpdirname:
output.workflow.save_pretrained(tmpdirname)
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
sd_pipe.load_workflow(tmpdirname)
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs)
image = output.images
workflow_image_slice = image[0, -3:, -3:, -1]
self.assertTrue(np.allclose(image_slice, workflow_image_slice))
@is_staging_test
class WorkflowPushToHubTester(unittest.TestCase):
identifier = uuid.uuid4()
repo_id = f"test-workflow-{identifier}"
org_repo_id = f"valid_org/{repo_id}-org"
def compare_workflow_values(self, repo_id: str, actual_workflow: dict):
local_path = hf_hub_download(repo_id=repo_id, filename=WORKFLOW_NAME, token=TOKEN)
with open(local_path) as f:
locally_loaded_workflow = json.load(f)
for k in actual_workflow:
assert actual_workflow[k] == locally_loaded_workflow[k]
def test_push_to_hub(self):
workflow = Workflow()
workflow.update({"prompt": "hey", "num_inference_steps": 25})
workflow.push_to_hub(self.repo_id, token=TOKEN)
self.compare_workflow_values(repo_id=f"{USER}/{self.repo_id}", actual_workflow=workflow)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
def test_push_to_hub_in_organization(self):
workflow = Workflow()
workflow.update({"prompt": "hey", "num_inference_steps": 25})
workflow.push_to_hub(self.org_repo_id, token=TOKEN)
self.compare_workflow_values(repo_id=self.org_repo_id, actual_workflow=workflow)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)