Compare commits
110 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b08fc2d9d5 | |||
| c46711e895 | |||
| a031abdc89 | |||
| 1d686bac81 | |||
| 0a401b95b7 | |||
| 664e931bcb | |||
| 88bdd97ccd | |||
| 08b453e382 | |||
| 2a111bc9fe | |||
| 16e6997f0d | |||
| 3b9b98656e | |||
| b65928b556 | |||
| 6bf1ca2c79 | |||
| 978dec9014 | |||
| 79a7ab92d1 | |||
| c2717317f0 | |||
| bf7f9b49a2 | |||
| e192ae08d3 | |||
| 87a09d66f3 | |||
| 75ada25048 | |||
| 2243a59483 | |||
| 466d32c442 | |||
| 20ba1fdbbd | |||
| ab6672fecd | |||
| f90a5139a2 | |||
| a2bc2e14b9 | |||
| f427345ab1 | |||
| 6e221334cd | |||
| 53bc30dd45 | |||
| eacf5e34eb | |||
| 4c05f7856a | |||
| bbd3572044 | |||
| f948778322 | |||
| 4684ea2fe8 | |||
| b64f835ea7 | |||
| 880c0fdd36 | |||
| c36f1c3160 | |||
| 0a08d41961 | |||
| e185084a5d | |||
| b21729225a | |||
| 8a812e4e14 | |||
| bf92e746c0 | |||
| b785a155d6 | |||
| d486f0e846 | |||
| 3351270627 | |||
| 4520e1221a | |||
| 618260409f | |||
| dadd55fb36 | |||
| 1b6c7ea74e | |||
| b41f809a4e | |||
| 0f55c17e17 | |||
| 5058d27f12 | |||
| 748c1b3ec7 | |||
| 523507034f | |||
| 46c751e970 | |||
| bc1d28c888 | |||
| af378c1dd1 | |||
| 6ba4c5395f | |||
| c1e4529541 | |||
| d29d97b616 | |||
| 7d4a257c7f | |||
| 141cd52d56 | |||
| f72b28c75b | |||
| ada8109d5b | |||
| b34acbdcbc | |||
| 63f767ef15 | |||
| d1b2a1a957 | |||
| 01782c220e | |||
| d63a498c3b | |||
| 6a4aad43dc | |||
| ddd8bd53ed | |||
| 9f7b2cf2dc | |||
| 895c4b704b | |||
| 636feba552 | |||
| 79dc7df03e | |||
| 6031ecbd23 | |||
| fdd003d8e2 | |||
| 172acc98b9 | |||
| 5ae3c3a56b | |||
| 21bc59ab24 | |||
| 50a749e909 | |||
| d9075be494 | |||
| b135b6e905 | |||
| 14a0d21d2e | |||
| ebf581e85f | |||
| e550163b9f | |||
| 20f0cbc88f | |||
| d72a24b790 | |||
| d3cda804e7 | |||
| 07eac4d65a | |||
| c079cae3d4 | |||
| c7bfb8b22a | |||
| 67d070749a | |||
| 9c357bda3f | |||
| 3f7c3511dc | |||
| 7d6f30e89b | |||
| 6d2e19f746 | |||
| 2a7f43a73b | |||
| b978334d71 | |||
| e5f232f76b | |||
| 3003ff4947 | |||
| 5ffa603244 | |||
| 0eeee618cf | |||
| 93f1a14cab | |||
| 13d73d9303 | |||
| ba352aea29 | |||
| 6fac1369d0 | |||
| 1093f9d615 | |||
| 81780882b8 | |||
| ebc7bedeb7 |
@@ -0,0 +1,52 @@
|
||||
name: Benchmarking tests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "30 1 1,15 * *" # every 2 weeks on the 1st and the 15th of every month at 1:30 AM
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
|
||||
jobs:
|
||||
torch_pipelines_cuda_benchmark_tests:
|
||||
name: Torch Core Pipelines CUDA Benchmarking Tests
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 1
|
||||
runs-on: [single-gpu, nvidia-gpu, a10, ci]
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install pandas
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Diffusers Benchmarking
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
|
||||
BASE_PATH: benchmark_outputs
|
||||
run: |
|
||||
export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
|
||||
cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: benchmark_test_reports
|
||||
path: benchmarks/benchmark_outputs
|
||||
@@ -0,0 +1,170 @@
|
||||
name: Fast tests for PRs - Test Fetcher
|
||||
|
||||
on: workflow_dispatch
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
setup_pr_tests:
|
||||
name: Setup PR Tests
|
||||
runs-on: docker-cpu
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
outputs:
|
||||
matrix: ${{ steps.set_matrix.outputs.matrix }}
|
||||
test_map: ${{ steps.set_matrix.outputs.test_map }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
echo $(git --version)
|
||||
- name: Fetch Tests
|
||||
run: |
|
||||
python utils/tests_fetcher.py | tee test_preparation.txt
|
||||
- name: Report fetched tests
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: test_fetched
|
||||
path: test_preparation.txt
|
||||
- id: set_matrix
|
||||
name: Create Test Matrix
|
||||
# The `keys` is used as GitHub actions matrix for jobs, i.e. `models`, `pipelines`, etc.
|
||||
# The `test_map` is used to get the actual identified test files under each key.
|
||||
# If no test to run (so no `test_map.json` file), create a dummy map (empty matrix will fail)
|
||||
run: |
|
||||
if [ -f test_map.json ]; then
|
||||
keys=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); d = list(test_map.keys()); print(json.dumps(d))')
|
||||
test_map=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); print(json.dumps(test_map))')
|
||||
else
|
||||
keys=$(python3 -c 'keys = ["dummy"]; print(keys)')
|
||||
test_map=$(python3 -c 'test_map = {"dummy": []}; print(test_map)')
|
||||
fi
|
||||
echo $keys
|
||||
echo $test_map
|
||||
echo "matrix=$keys" >> $GITHUB_OUTPUT
|
||||
echo "test_map=$test_map" >> $GITHUB_OUTPUT
|
||||
|
||||
run_pr_tests:
|
||||
name: Run PR Tests
|
||||
needs: setup_pr_tests
|
||||
if: contains(fromJson(needs.setup_pr_tests.outputs.matrix), 'dummy') != true
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
modules: ${{ fromJson(needs.setup_pr_tests.outputs.matrix) }}
|
||||
runs-on: docker-cpu
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install accelerate
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run all selected tests on CPU
|
||||
run: |
|
||||
python -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
|
||||
cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: ${{ matrix.modules }}_test_reports
|
||||
path: reports
|
||||
|
||||
run_staging_tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Hub tests for models, schedulers, and pipelines
|
||||
framework: hub_tests_pytorch
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_hub
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run Hub tests for models, schedulers, and pipelines on a staging env
|
||||
if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
|
||||
run: |
|
||||
HUGGINGFACE_CO_STAGING=true python -m pytest \
|
||||
-m "is_staging_test" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
@@ -113,9 +113,10 @@ jobs:
|
||||
- name: Run example PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_examples' }}
|
||||
run: |
|
||||
python -m pip install peft
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
examples/test_examples.py
|
||||
examples
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -5,6 +5,10 @@ on:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HOME: /mnt/cache
|
||||
@@ -96,7 +100,7 @@ jobs:
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
examples/test_examples.py
|
||||
examples
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -13,6 +13,10 @@ env:
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: no
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
run_fast_tests_apple_m1:
|
||||
name: Fast PyTorch MPS tests on MacOS
|
||||
|
||||
+1
-1
@@ -355,7 +355,7 @@ You will need basic `git` proficiency to be able to contribute to
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L244)):
|
||||
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L265)):
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/diffusers) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
|
||||
export PYTHONPATH = src
|
||||
|
||||
check_dirs := examples scripts src tests utils
|
||||
check_dirs := examples scripts src tests utils benchmarks
|
||||
|
||||
modified_only_fixup:
|
||||
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
|
||||
@@ -41,7 +41,7 @@ repo-consistency:
|
||||
|
||||
quality:
|
||||
ruff check $(check_dirs) setup.py
|
||||
ruff format --check $(check_dirs) setup.py
|
||||
ruff format --check $(check_dirs) setup.py
|
||||
python utils/check_doc_toc.py
|
||||
|
||||
# Format source code automatically and check is there are any problems left that need manual fixing
|
||||
|
||||
+1
-1
@@ -82,7 +82,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
|
||||
The following design principles are followed:
|
||||
- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
|
||||
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
|
||||
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modelling files and shows that models do not really follow the single-file policy.
|
||||
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
|
||||
- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
|
||||
- Models all inherit from `ModelMixin` and `ConfigMixin`.
|
||||
- Models can be optimized for performance when it doesn’t demand major code changes, keep backward compatibility, and give significant memory or compute gain.
|
||||
|
||||
@@ -77,7 +77,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
|
||||
|
||||
## Quickstart
|
||||
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 15000+ checkpoints):
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 16000+ checkpoints):
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -219,7 +219,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
|
||||
- https://github.com/deep-floyd/IF
|
||||
- https://github.com/bentoml/BentoML
|
||||
- https://github.com/bmaltais/kohya_ss
|
||||
- +6000 other amazing GitHub repositories 💪
|
||||
- +7000 other amazing GitHub repositories 💪
|
||||
|
||||
Thank you for using us ❤️.
|
||||
|
||||
|
||||
@@ -0,0 +1,297 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
AutoPipelineForText2Image,
|
||||
ControlNetModel,
|
||||
LCMScheduler,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
T2IAdapter,
|
||||
WuerstchenCombinedPipeline,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import ( # noqa: E402
|
||||
BASE_PATH,
|
||||
PROMPT,
|
||||
BenchmarkInfo,
|
||||
benchmark_fn,
|
||||
bytes_to_giga_bytes,
|
||||
flush,
|
||||
generate_csv_dict,
|
||||
write_to_csv,
|
||||
)
|
||||
|
||||
|
||||
RESOLUTION_MAPPING = {
|
||||
"runwayml/stable-diffusion-v1-5": (512, 512),
|
||||
"lllyasviel/sd-controlnet-canny": (512, 512),
|
||||
"diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
|
||||
"TencentARC/t2iadapter_canny_sd14v1": (512, 512),
|
||||
"TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024),
|
||||
"stabilityai/stable-diffusion-2-1": (768, 768),
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
|
||||
"stabilityai/sdxl-turbo": (512, 512),
|
||||
}
|
||||
|
||||
|
||||
class BaseBenchmak:
|
||||
pipeline_class = None
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
|
||||
def run_inference(self, args):
|
||||
raise NotImplementedError
|
||||
|
||||
def benchmark(self, args):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_result_filepath(self, args):
|
||||
pipeline_class_name = str(self.pipe.__class__.__name__)
|
||||
name = (
|
||||
args.ckpt.replace("/", "_")
|
||||
+ "_"
|
||||
+ pipeline_class_name
|
||||
+ f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
|
||||
)
|
||||
filepath = os.path.join(BASE_PATH, name)
|
||||
return filepath
|
||||
|
||||
|
||||
class TextToImageBenchmark(BaseBenchmak):
|
||||
pipeline_class = AutoPipelineForText2Image
|
||||
|
||||
def __init__(self, args):
|
||||
pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
if args.run_compile:
|
||||
if not isinstance(pipe, WuerstchenCombinedPipeline):
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
print("Run torch compile")
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None:
|
||||
pipe.movq.to(memory_format=torch.channels_last)
|
||||
pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True)
|
||||
else:
|
||||
print("Run torch compile")
|
||||
pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
self.pipe = pipe
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
def benchmark(self, args):
|
||||
flush()
|
||||
|
||||
print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
|
||||
|
||||
time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
|
||||
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
|
||||
benchmark_info = BenchmarkInfo(time=time, memory=memory)
|
||||
|
||||
pipeline_class_name = str(self.pipe.__class__.__name__)
|
||||
flush()
|
||||
csv_dict = generate_csv_dict(
|
||||
pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info
|
||||
)
|
||||
filepath = self.get_result_filepath(args)
|
||||
write_to_csv(filepath, csv_dict)
|
||||
print(f"Logs written to: {filepath}")
|
||||
flush()
|
||||
|
||||
|
||||
class TurboTextToImageBenchmark(TextToImageBenchmark):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
guidance_scale=0.0,
|
||||
)
|
||||
|
||||
|
||||
class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
|
||||
lora_id = "latent-consistency/lcm-lora-sdxl"
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.pipe.load_lora_weights(self.lora_id)
|
||||
self.pipe.fuse_lora()
|
||||
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
||||
|
||||
def get_result_filepath(self, args):
|
||||
pipeline_class_name = str(self.pipe.__class__.__name__)
|
||||
name = (
|
||||
self.lora_id.replace("/", "_")
|
||||
+ "_"
|
||||
+ pipeline_class_name
|
||||
+ f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
|
||||
)
|
||||
filepath = os.path.join(BASE_PATH, name)
|
||||
return filepath
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
|
||||
class ImageToImageBenchmark(TextToImageBenchmark):
|
||||
pipeline_class = AutoPipelineForImage2Image
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
|
||||
image = load_image(url).convert("RGB")
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
image=self.image,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
|
||||
class TurboImageToImageBenchmark(ImageToImageBenchmark):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
image=self.image,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
guidance_scale=0.0,
|
||||
strength=0.5,
|
||||
)
|
||||
|
||||
|
||||
class InpaintingBenchmark(ImageToImageBenchmark):
|
||||
pipeline_class = AutoPipelineForInpainting
|
||||
mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
mask = load_image(mask_url).convert("RGB")
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
image=self.image,
|
||||
mask_image=self.mask,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetBenchmark(TextToImageBenchmark):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
aux_network_class = ControlNetModel
|
||||
root_ckpt = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
|
||||
image = load_image(url).convert("RGB")
|
||||
|
||||
def __init__(self, args):
|
||||
aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
|
||||
pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
self.pipe = pipe
|
||||
|
||||
if args.run_compile:
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.controlnet.to(memory_format=torch.channels_last)
|
||||
|
||||
print("Run torch compile")
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
image=self.image,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetSDXLBenchmark(ControlNetBenchmark):
|
||||
pipeline_class = StableDiffusionXLControlNetPipeline
|
||||
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
|
||||
class T2IAdapterBenchmark(ControlNetBenchmark):
|
||||
pipeline_class = StableDiffusionAdapterPipeline
|
||||
aux_network_class = T2IAdapter
|
||||
root_ckpt = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
|
||||
image = load_image(url).convert("L")
|
||||
|
||||
def __init__(self, args):
|
||||
aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
|
||||
pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
self.pipe = pipe
|
||||
|
||||
if args.run_compile:
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.adapter.to(memory_format=torch.channels_last)
|
||||
|
||||
print("Run torch compile")
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
|
||||
|
||||
class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark):
|
||||
pipeline_class = StableDiffusionXLAdapterPipeline
|
||||
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png"
|
||||
image = load_image(url)
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
@@ -0,0 +1,26 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="lllyasviel/sd-controlnet-canny",
|
||||
choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"],
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = (
|
||||
ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args)
|
||||
)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -0,0 +1,29 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import ImageToImageBenchmark, TurboImageToImageBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
choices=[
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
"stabilityai/sdxl-turbo",
|
||||
],
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = ImageToImageBenchmark(args) if "turbo" not in args.ckpt else TurboImageToImageBenchmark(args)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -0,0 +1,28 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import InpaintingBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
choices=[
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
],
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = InpaintingBenchmark(args)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -0,0 +1,28 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import T2IAdapterBenchmark, T2IAdapterSDXLBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="TencentARC/t2iadapter_canny_sd14v1",
|
||||
choices=["TencentARC/t2iadapter_canny_sd14v1", "TencentARC/t2i-adapter-canny-sdxl-1.0"],
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = (
|
||||
T2IAdapterBenchmark(args)
|
||||
if args.ckpt == "TencentARC/t2iadapter_canny_sd14v1"
|
||||
else T2IAdapterSDXLBenchmark(args)
|
||||
)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -0,0 +1,23 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import LCMLoRATextToImageBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="stabilityai/stable-diffusion-xl-base-1.0",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=4)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = LCMLoRATextToImageBenchmark(args)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -0,0 +1,40 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark # noqa: E402
|
||||
|
||||
|
||||
ALL_T2I_CKPTS = [
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"segmind/SSD-1B",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"kandinsky-community/kandinsky-2-2-decoder",
|
||||
"warp-ai/wuerstchen",
|
||||
"stabilityai/sdxl-turbo",
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
choices=ALL_T2I_CKPTS,
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_cls = None
|
||||
if "turbo" in args.ckpt:
|
||||
benchmark_cls = TurboTextToImageBenchmark
|
||||
else:
|
||||
benchmark_cls = TextToImageBenchmark
|
||||
|
||||
benchmark_pipe = benchmark_cls(args)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -0,0 +1,72 @@
|
||||
import glob
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from huggingface_hub.utils._errors import EntryNotFoundError
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from utils import BASE_PATH, FINAL_CSV_FILE, GITHUB_SHA, REPO_ID, collate_csv # noqa: E402
|
||||
|
||||
|
||||
def has_previous_benchmark() -> str:
|
||||
csv_path = None
|
||||
try:
|
||||
csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILE)
|
||||
except EntryNotFoundError:
|
||||
csv_path = None
|
||||
return csv_path
|
||||
|
||||
|
||||
def filter_float(value):
|
||||
if isinstance(value, str):
|
||||
return float(value.split()[0])
|
||||
return value
|
||||
|
||||
|
||||
def push_to_hf_dataset():
|
||||
all_csvs = sorted(glob.glob(f"{BASE_PATH}/*.csv"))
|
||||
collate_csv(all_csvs, FINAL_CSV_FILE)
|
||||
|
||||
# If there's an existing benchmark file, we should report the changes.
|
||||
csv_path = has_previous_benchmark()
|
||||
if csv_path is not None:
|
||||
current_results = pd.read_csv(FINAL_CSV_FILE)
|
||||
previous_results = pd.read_csv(csv_path)
|
||||
|
||||
numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns
|
||||
numeric_columns = [
|
||||
c for c in numeric_columns if c not in ["batch_size", "num_inference_steps", "actual_gpu_memory (gbs)"]
|
||||
]
|
||||
|
||||
for column in numeric_columns:
|
||||
previous_results[column] = previous_results[column].map(lambda x: filter_float(x))
|
||||
|
||||
# Calculate the percentage change
|
||||
current_results[column] = current_results[column].astype(float)
|
||||
previous_results[column] = previous_results[column].astype(float)
|
||||
percent_change = ((current_results[column] - previous_results[column]) / previous_results[column]) * 100
|
||||
|
||||
# Format the values with '+' or '-' sign and append to original values
|
||||
current_results[column] = current_results[column].map(str) + percent_change.map(
|
||||
lambda x: f" ({'+' if x > 0 else ''}{x:.2f}%)"
|
||||
)
|
||||
# There might be newly added rows. So, filter out the NaNs.
|
||||
current_results[column] = current_results[column].map(lambda x: x.replace(" (nan%)", ""))
|
||||
|
||||
# Overwrite the current result file.
|
||||
current_results.to_csv(FINAL_CSV_FILE, index=False)
|
||||
|
||||
commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
|
||||
upload_file(
|
||||
repo_id=REPO_ID,
|
||||
path_in_repo=FINAL_CSV_FILE,
|
||||
path_or_fileobj=FINAL_CSV_FILE,
|
||||
repo_type="dataset",
|
||||
commit_message=commit_message,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
push_to_hf_dataset()
|
||||
@@ -0,0 +1,97 @@
|
||||
import glob
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from benchmark_text_to_image import ALL_T2I_CKPTS # noqa: E402
|
||||
|
||||
|
||||
PATTERN = "benchmark_*.py"
|
||||
|
||||
|
||||
class SubprocessCallException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Taken from `test_examples_utils.py`
|
||||
def run_command(command: List[str], return_stdout=False):
|
||||
"""
|
||||
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
|
||||
if an error occurred while running `command`
|
||||
"""
|
||||
try:
|
||||
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
|
||||
if return_stdout:
|
||||
if hasattr(output, "decode"):
|
||||
output = output.decode("utf-8")
|
||||
return output
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise SubprocessCallException(
|
||||
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
||||
) from e
|
||||
|
||||
|
||||
def main():
|
||||
python_files = glob.glob(PATTERN)
|
||||
|
||||
for file in python_files:
|
||||
print(f"****** Running file: {file} ******")
|
||||
|
||||
# Run with canonical settings.
|
||||
if file != "benchmark_text_to_image.py":
|
||||
command = f"python {file}"
|
||||
run_command(command.split())
|
||||
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
|
||||
# Run variants.
|
||||
for file in python_files:
|
||||
if file == "benchmark_text_to_image.py":
|
||||
for ckpt in ALL_T2I_CKPTS:
|
||||
command = f"python {file} --ckpt {ckpt}"
|
||||
|
||||
if "turbo" in ckpt:
|
||||
command += " --num_inference_steps 1"
|
||||
|
||||
run_command(command.split())
|
||||
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
|
||||
elif file == "benchmark_sd_img.py":
|
||||
for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
|
||||
command = f"python {file} --ckpt {ckpt}"
|
||||
|
||||
if ckpt == "stabilityai/sdxl-turbo":
|
||||
command += " --num_inference_steps 2"
|
||||
|
||||
run_command(command.split())
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
|
||||
elif file == "benchmark_sd_inpainting.py":
|
||||
sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
command = f"python {file} --ckpt {sdxl_ckpt}"
|
||||
run_command(command.split())
|
||||
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
|
||||
elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]:
|
||||
sdxl_ckpt = (
|
||||
"diffusers/controlnet-canny-sdxl-1.0"
|
||||
if "controlnet" in file
|
||||
else "TencentARC/t2i-adapter-canny-sdxl-1.0"
|
||||
)
|
||||
command = f"python {file} --ckpt {sdxl_ckpt}"
|
||||
run_command(command.split())
|
||||
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,98 @@
|
||||
import argparse
|
||||
import csv
|
||||
import gc
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
|
||||
GITHUB_SHA = os.getenv("GITHUB_SHA", None)
|
||||
BENCHMARK_FIELDS = [
|
||||
"pipeline_cls",
|
||||
"ckpt_id",
|
||||
"batch_size",
|
||||
"num_inference_steps",
|
||||
"model_cpu_offload",
|
||||
"run_compile",
|
||||
"time (secs)",
|
||||
"memory (gbs)",
|
||||
"actual_gpu_memory (gbs)",
|
||||
"github_sha",
|
||||
]
|
||||
|
||||
PROMPT = "ghibli style, a fantasy landscape with castles"
|
||||
BASE_PATH = os.getenv("BASE_PATH", ".")
|
||||
TOTAL_GPU_MEMORY = float(os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3)))
|
||||
|
||||
REPO_ID = "diffusers/benchmarks"
|
||||
FINAL_CSV_FILE = "collated_results.csv"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkInfo:
|
||||
time: float
|
||||
memory: float
|
||||
|
||||
|
||||
def flush():
|
||||
"""Wipes off memory."""
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
def bytes_to_giga_bytes(bytes):
|
||||
return f"{(bytes / 1024 / 1024 / 1024):.3f}"
|
||||
|
||||
|
||||
def benchmark_fn(f, *args, **kwargs):
|
||||
t0 = benchmark.Timer(
|
||||
stmt="f(*args, **kwargs)",
|
||||
globals={"args": args, "kwargs": kwargs, "f": f},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
return f"{(t0.blocked_autorange().mean):.3f}"
|
||||
|
||||
|
||||
def generate_csv_dict(
|
||||
pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo
|
||||
) -> Dict[str, Union[str, bool, float]]:
|
||||
"""Packs benchmarking data into a dictionary for latter serialization."""
|
||||
data_dict = {
|
||||
"pipeline_cls": pipeline_cls,
|
||||
"ckpt_id": ckpt,
|
||||
"batch_size": args.batch_size,
|
||||
"num_inference_steps": args.num_inference_steps,
|
||||
"model_cpu_offload": args.model_cpu_offload,
|
||||
"run_compile": args.run_compile,
|
||||
"time (secs)": benchmark_info.time,
|
||||
"memory (gbs)": benchmark_info.memory,
|
||||
"actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}",
|
||||
"github_sha": GITHUB_SHA,
|
||||
}
|
||||
return data_dict
|
||||
|
||||
|
||||
def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]):
|
||||
"""Serializes a dictionary into a CSV file."""
|
||||
with open(file_name, mode="w", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=BENCHMARK_FIELDS)
|
||||
writer.writeheader()
|
||||
writer.writerow(data_dict)
|
||||
|
||||
|
||||
def collate_csv(input_files: List[str], output_file: str):
|
||||
"""Collates multiple identically structured CSVs into a single CSV file."""
|
||||
with open(output_file, mode="w", newline="") as outfile:
|
||||
writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS)
|
||||
writer.writeheader()
|
||||
|
||||
for file in input_files:
|
||||
with open(file, mode="r") as infile:
|
||||
reader = csv.DictReader(infile)
|
||||
for row in reader:
|
||||
writer.writerow(row)
|
||||
@@ -72,6 +72,8 @@
|
||||
title: Overview
|
||||
- local: using-diffusers/sdxl
|
||||
title: Stable Diffusion XL
|
||||
- local: using-diffusers/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: using-diffusers/kandinsky
|
||||
title: Kandinsky
|
||||
- local: using-diffusers/controlnet
|
||||
@@ -94,6 +96,8 @@
|
||||
title: Latent Consistency Model-LoRA
|
||||
- local: using-diffusers/inference_with_lcm
|
||||
title: Latent Consistency Model
|
||||
- local: using-diffusers/svd
|
||||
title: Stable Video Diffusion
|
||||
title: Specific pipeline examples
|
||||
- sections:
|
||||
- local: training/overview
|
||||
@@ -129,6 +133,8 @@
|
||||
title: LoRA
|
||||
- local: training/custom_diffusion
|
||||
title: Custom Diffusion
|
||||
- local: training/lcm_distill
|
||||
title: Latent Consistency Distillation
|
||||
- local: training/ddpo
|
||||
title: Reinforcement learning training with DDPO
|
||||
title: Methods
|
||||
@@ -258,6 +264,10 @@
|
||||
title: ControlNet
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
title: ControlNet with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnetxs
|
||||
title: ControlNet-XS
|
||||
- local: api/pipelines/controlnetxs_sdxl
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/cycle_diffusion
|
||||
title: Cycle Diffusion
|
||||
- local: api/pipelines/dance_diffusion
|
||||
@@ -278,6 +288,8 @@
|
||||
title: Kandinsky 2.1
|
||||
- local: api/pipelines/kandinsky_v22
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
title: Latent Consistency Models
|
||||
- local: api/pipelines/latent_diffusion
|
||||
@@ -327,12 +339,14 @@
|
||||
title: Stable Diffusion 2
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
|
||||
title: Stable Diffusion XL
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
title: Super-resolution
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth)
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
|
||||
- local: api/pipelines/stable_diffusion/adapter
|
||||
title: Stable Diffusion T2I-Adapter
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
|
||||
@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mech
|
||||
## AttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.AttnProcessor2_0
|
||||
|
||||
## FusedAttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
|
||||
|
||||
## LoRAAttnProcessor
|
||||
[[autodoc]] models.attention_processor.LoRAAttnProcessor
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNet-XS
|
||||
|
||||
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
|
||||
|
||||
Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
|
||||
|
||||
ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb) with StableDiffusion-XL) and uses ~45% less memory.
|
||||
|
||||
Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/):
|
||||
|
||||
*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.*
|
||||
|
||||
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionControlNetXSPipeline
|
||||
[[autodoc]] StableDiffusionControlNetXSPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
@@ -0,0 +1,45 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNet-XS with Stable Diffusion XL
|
||||
|
||||
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
|
||||
|
||||
Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
|
||||
|
||||
ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb)) and uses ~45% less memory.
|
||||
|
||||
Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/):
|
||||
|
||||
*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.*
|
||||
|
||||
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionXLControlNetXSPipeline
|
||||
[[autodoc]] StableDiffusionXLControlNetXSPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
@@ -0,0 +1,49 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Kandinsky 3
|
||||
|
||||
Kandinsky 3 is created by [Vladimir Arkhipkin](https://github.com/oriBetelgeuse),[Anastasia Maltseva](https://github.com/NastyaMittseva),[Igor Pavlov](https://github.com/boomb0om),[Andrei Filatov](https://github.com/anvilarth),[Arseniy Shakhmatov](https://github.com/cene555),[Andrey Kuznetsov](https://github.com/kuznetsoffandrey),[Denis Dimitrov](https://github.com/denndimitrov), [Zein Shaheen](https://github.com/zeinsh)
|
||||
|
||||
The description from it's Github page:
|
||||
|
||||
*Kandinsky 3.0 is an open-source text-to-image diffusion model built upon the Kandinsky2-x model family. In comparison to its predecessors, enhancements have been made to the text understanding and visual quality of the model, achieved by increasing the size of the text encoder and Diffusion U-Net models, respectively.*
|
||||
|
||||
Its architecture includes 3 main components:
|
||||
1. [FLAN-UL2](https://huggingface.co/google/flan-ul2), which is an encoder decoder model based on the T5 architecture.
|
||||
2. New U-Net architecture featuring BigGAN-deep blocks doubles depth while maintaining the same number of parameters.
|
||||
3. Sber-MoVQGAN is a decoder proven to have superior results in image restoration.
|
||||
|
||||
|
||||
|
||||
The original codebase can be found at [ai-forever/Kandinsky-3](https://github.com/ai-forever/Kandinsky-3).
|
||||
|
||||
<Tip>
|
||||
|
||||
Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Kandinsky3Pipeline
|
||||
|
||||
[[autodoc]] Kandinsky3Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Kandinsky3Img2ImgPipeline
|
||||
|
||||
[[autodoc]] Kandinsky3Img2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -40,6 +40,8 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [Consistency Models](consistency_models) | unconditional image generation |
|
||||
| [ControlNet](controlnet) | text2image, image2image, inpainting |
|
||||
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
|
||||
| [ControlNet-XS](controlnetxs) | text2image |
|
||||
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
|
||||
| [Cycle Diffusion](cycle_diffusion) | image2image |
|
||||
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
|
||||
| [DDIM](ddim) | unconditional image generation |
|
||||
@@ -51,9 +53,10 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [InstructPix2Pix](pix2pix) | image editing |
|
||||
| [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation |
|
||||
| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |
|
||||
| [Kandinsky 3](kandinsky3) | text2image, image2image |
|
||||
| [Latent Consistency Models](latent_consistency_models) | text2image |
|
||||
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
|
||||
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D |
|
||||
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D, text-to-pano, upscaling |
|
||||
| [MultiDiffusion](panorama) | text2image |
|
||||
| [MusicLDM](musicldm) | text2audio |
|
||||
| [Paint by Example](paint_by_example) | inpainting |
|
||||
@@ -70,6 +73,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution |
|
||||
| [Stable Diffusion Model Editing](model_editing) | model editing |
|
||||
| [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting |
|
||||
| [Stable Diffusion XL Turbo](stable_diffusion/sdxl_turbo) | text2image, image2image, inpainting |
|
||||
| [Stable unCLIP](stable_unclip) | text2image, image variation |
|
||||
| [Stochastic Karras VE](stochastic_karras_ve) | unconditional image generation |
|
||||
| [T2I-Adapter](stable_diffusion/adapter) | text2image |
|
||||
|
||||
@@ -35,6 +35,112 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
|
||||
</Tip>
|
||||
|
||||
## Inference with under 8GB GPU VRAM
|
||||
|
||||
Run the [`PixArtAlphaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example.
|
||||
|
||||
First, install the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library:
|
||||
|
||||
```bash
|
||||
pip install -U bitsandbytes
|
||||
```
|
||||
|
||||
Then load the text encoder in 8-bit:
|
||||
|
||||
```python
|
||||
from transformers import T5EncoderModel
|
||||
from diffusers import PixArtAlphaPipeline
|
||||
import torch
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
||||
subfolder="text_encoder",
|
||||
load_in_8bit=True,
|
||||
device_map="auto",
|
||||
|
||||
)
|
||||
pipe = PixArtAlphaPipeline.from_pretrained(
|
||||
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
||||
text_encoder=text_encoder,
|
||||
transformer=None,
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
Now, use the `pipe` to encode a prompt:
|
||||
|
||||
```python
|
||||
with torch.no_grad():
|
||||
prompt = "cute cat"
|
||||
prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)
|
||||
```
|
||||
|
||||
Since text embeddings have been computed, remove the `text_encoder` and `pipe` from the memory, and free up som GPU VRAM:
|
||||
|
||||
```python
|
||||
import gc
|
||||
|
||||
def flush():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
del text_encoder
|
||||
del pipe
|
||||
flush()
|
||||
```
|
||||
|
||||
Then compute the latents with the prompt embeddings as inputs:
|
||||
|
||||
```python
|
||||
pipe = PixArtAlphaPipeline.from_pretrained(
|
||||
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
||||
text_encoder=None,
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
latents = pipe(
|
||||
negative_prompt=None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
num_images_per_prompt=1,
|
||||
output_type="latent",
|
||||
).images
|
||||
|
||||
del pipe.transformer
|
||||
flush()
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
|
||||
|
||||
</Tip>
|
||||
|
||||
Once the latents are computed, pass it off to the VAE to decode into a real image:
|
||||
|
||||
```python
|
||||
with torch.no_grad():
|
||||
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
By deleting components you aren't using and flushing the GPU VRAM, you should be able to run [`PixArtAlphaPipeline`] with under 8GB GPU VRAM.
|
||||
|
||||

|
||||
|
||||
If you want a report of your memory-usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
|
||||
|
||||
</Tip>
|
||||
|
||||
While loading the `text_encoder`, you set `load_in_8bit` to `True`. You could also specify `load_in_4bit` to bring your memory requirements down even further to under 7GB.
|
||||
|
||||
## PixArtAlphaPipeline
|
||||
|
||||
[[autodoc]] PixArtAlphaPipeline
|
||||
|
||||
@@ -14,6 +14,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
LDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://huggingface.co/papers/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal. LDM3D generates an image and a depth map from a given text prompt unlike the existing text-to-image diffusion models such as [Stable Diffusion](./overview) which only generates an image. With almost the same number of parameters, LDM3D achieves to create a latent space that can compress both the RGB images and the depth maps.
|
||||
|
||||
Two checkpoints are available for use:
|
||||
- [ldm3d-original](https://huggingface.co/Intel/ldm3d). The original checkpoint used in the [paper](https://arxiv.org/pdf/2305.10853.pdf)
|
||||
- [ldm3d-4c](https://huggingface.co/Intel/ldm3d-4c). The new version of LDM3D using 4 channels inputs instead of 6-channels inputs and finetuned on higher resolution images.
|
||||
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*This research paper proposes a Latent Diffusion Model for 3D (LDM3D) that generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. The LDM3D model is fine-tuned on a dataset of tuples containing an RGB image, depth map and caption, and validated through extensive experiments. We also develop an application called DepthFusion, which uses the generated RGB images and depth maps to create immersive and interactive 360-degree-view experiences using TouchDesigner. This technology has the potential to transform a wide range of industries, from entertainment and gaming to architecture and design. Overall, this paper presents a significant contribution to the field of generative AI and computer vision, and showcases the potential of LDM3D and DepthFusion to revolutionize content creation and digital experiences. A short video summarizing the approach can be found at [this url](https://t.ly/tdi2).*
|
||||
@@ -26,12 +31,25 @@ Make sure to check out the Stable Diffusion [Tips](overview#tips) section to lea
|
||||
|
||||
## StableDiffusionLDM3DPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionLDM3DPipeline
|
||||
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## LDM3DPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput
|
||||
- all
|
||||
- __call__
|
||||
|
||||
# Upscaler
|
||||
|
||||
[LDM3D-VR](https://arxiv.org/pdf/2311.03226.pdf) is an extended version of LDM3D.
|
||||
|
||||
The abstract from the paper is:
|
||||
*Latent diffusion models have proven to be state-of-the-art in the creation and manipulation of visual outputs. However, as far as we know, the generation of depth maps jointly with RGB is still limited. We introduce LDM3D-VR, a suite of diffusion models targeting virtual reality development that includes LDM3D-pano and LDM3D-SR. These models enable the generation of panoramic RGBD based on textual prompts and the upscaling of low-resolution inputs to high-resolution RGBD, respectively. Our models are fine-tuned from existing pretrained models on datasets containing panoramic/high-resolution RGB images, depth maps and captions. Both models are evaluated in comparison to existing related methods*
|
||||
|
||||
Two checkpoints are available for use:
|
||||
- [ldm3d-pano](https://huggingface.co/Intel/ldm3d-pano). This checkpoint enables the generation of panoramic images and requires the StableDiffusionLDM3DPipeline pipeline to be used.
|
||||
- [ldm3d-sr](https://huggingface.co/Intel/ldm3d-sr). This checkpoint enables the upscaling of RGB and depth images. Can be used in cascade after the original LDM3D pipeline using the StableDiffusionUpscaleLDM3DPipeline from communauty pipeline.
|
||||
|
||||
|
||||
@@ -121,10 +121,16 @@ The table below summarizes the available Stable Diffusion pipelines, their suppo
|
||||
<td class="px-4 py-2 text-gray-700">
|
||||
<a href="./ldm3d_diffusion">StableDiffusionLDM3D</a>
|
||||
</td>
|
||||
<td class="px-4 py-2 text-gray-700">text-to-rgb, text-to-depth</td>
|
||||
<td class="px-4 py-2 text-gray-700">text-to-rgb, text-to-depth, text-to-pano</td>
|
||||
<td class="px-4 py-2"><a href="https://huggingface.co/spaces/r23/ldm3d-space"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue"/></a>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="px-4 py-2 text-gray-700">
|
||||
<a href="./ldm3d_diffusion">StableDiffusionUpscaleLDM3D</a>
|
||||
</td>
|
||||
<td class="px-4 py-2 text-gray-700">ldm3d super-resolution</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# SDXL Turbo
|
||||
|
||||
Stable Diffusion XL (SDXL) Turbo was proposed in [Adversarial Diffusion Distillation](https://stability.ai/research/adversarial-diffusion-distillation) by Axel Sauer, Dominik Lorenz, Andreas Blattmann, and Robin Rombach.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce Adversarial Diffusion Distillation (ADD), a novel training approach that efficiently samples large-scale foundational image diffusion models in just 1–4 steps while maintaining high image quality. We use score distillation to leverage large-scale off-the-shelf image diffusion models as a teacher signal in combination with an adversarial loss to ensure high image fidelity even in the low-step regime of one or two sampling steps. Our analyses show that our model clearly outperforms existing few-step methods (GANs,Latent Consistency Models) in a single step and reaches the performance of state-of-the-art diffusion models (SDXL) in only four steps. ADD is the first method to unlock single-step, real-time image synthesis with foundation models.*
|
||||
|
||||
## Tips
|
||||
|
||||
- SDXL Turbo uses the exact same architecture as [SDXL](./stable_diffusion_xl), which means it also has the same API. Please refer to the [SDXL](./stable_diffusion_xl) API reference for more details.
|
||||
- SDXL Turbo should disable guidance scale by setting `guidance_scale=0.0`
|
||||
- SDXL Turbo should use `timestep_spacing='trailing'` for the scheduler and use between 1 and 4 steps.
|
||||
- SDXL Turbo has been trained to generate images of size 512x512.
|
||||
- SDXL Turbo is open-access, but not open-source meaning that one might have to buy a model license in order to use it for commercial applications. Make sure to read the [official model card](https://huggingface.co/stabilityai/sdxl-turbo) to learn more.
|
||||
|
||||
<Tip>
|
||||
|
||||
To learn how to use SDXL Turbo for various tasks, how to optimize performance, and other usage examples, take a look at the [SDXL Turbo](../../../using-diffusers/sdxl_turbo) guide.
|
||||
|
||||
Check out the [Stability AI](https://huggingface.co/stabilityai) Hub organization for the official base and refiner model checkpoints!
|
||||
|
||||
</Tip>
|
||||
@@ -92,6 +92,19 @@ imageio.mimsave("video.mp4", result, fps=4)
|
||||
```
|
||||
|
||||
|
||||
- #### SDXL Support
|
||||
In order to use the SDXL model when generating a video from prompt, use the `TextToVideoZeroSDXLPipeline` pipeline:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import TextToVideoZeroSDXLPipeline
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
|
||||
model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
### Text-To-Video with Pose Control
|
||||
To generate a video from prompt with additional pose control
|
||||
|
||||
@@ -141,7 +154,33 @@ To generate a video from prompt with additional pose control
|
||||
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
|
||||
imageio.mimsave("video.mp4", result, fps=4)
|
||||
```
|
||||
|
||||
- #### SDXL Support
|
||||
|
||||
Since our attention processor also works with SDXL, it can be utilized to generate a video from prompt using ControlNet models powered by SDXL:
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
|
||||
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
|
||||
|
||||
controlnet_model_id = 'thibaud/controlnet-openpose-sdxl-1.0'
|
||||
model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
model_id, controlnet=controlnet, torch_dtype=torch.float16
|
||||
).to('cuda')
|
||||
|
||||
# Set the attention processor
|
||||
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
|
||||
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
|
||||
|
||||
# fix latents for all frames
|
||||
latents = torch.randn((1, 4, 128, 128), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
|
||||
|
||||
prompt = "Darth Vader dancing in a desert"
|
||||
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
|
||||
imageio.mimsave("video.mp4", result, fps=4)
|
||||
```
|
||||
|
||||
### Text-To-Video with Edge Control
|
||||
|
||||
@@ -253,5 +292,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## TextToVideoZeroSDXLPipeline
|
||||
[[autodoc]] TextToVideoZeroSDXLPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## TextToVideoPipelineOutput
|
||||
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput
|
||||
|
||||
@@ -24,7 +24,7 @@ The abstract from the paper is:
|
||||
|
||||
*Model-based reinforcement learning methods often use learning only for the purpose of estimating an approximate dynamics model, offloading the rest of the decision-making work to classical trajectory optimizers. While conceptually simple, this combination has a number of empirical shortcomings, suggesting that learned models may not be well-suited to standard trajectory optimization. In this paper, we consider what it would look like to fold as much of the trajectory optimization pipeline as possible into the modeling problem, such that sampling from the model and planning with it become nearly identical. The core of our technical approach lies in a diffusion probabilistic model that plans by iteratively denoising trajectories. We show how classifier-guided sampling and image inpainting can be reinterpreted as coherent planning strategies, explore the unusual and useful properties of diffusion-based planning methods, and demonstrate the effectiveness of our framework in control settings that emphasize long-horizon decision-making and test-time flexibility.*
|
||||
|
||||
You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb).
|
||||
You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/drive/1rXm8CX4ZdN5qivjJ2lhwhkOmt_m0CvU0#scrollTo=6HXJvhyqcITc&uniqifier=1).
|
||||
|
||||
The script to run the model is available [here](https://github.com/huggingface/diffusers/tree/main/examples/reinforcement_learning).
|
||||
|
||||
|
||||
@@ -25,4 +25,4 @@ The abstract from the paper is:
|
||||
</Tip>
|
||||
|
||||
## ScoreSdeVpScheduler
|
||||
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
|
||||
[[autodoc]] schedulers.deprecated.scheduling_sde_vp.ScoreSdeVpScheduler
|
||||
|
||||
@@ -18,4 +18,4 @@ specific language governing permissions and limitations under the License.
|
||||
[[autodoc]] KarrasVeScheduler
|
||||
|
||||
## KarrasVeOutput
|
||||
[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeOutput
|
||||
[[autodoc]] schedulers.deprecated.scheduling_karras_ve.KarrasVeOutput
|
||||
@@ -297,17 +297,37 @@ if you don't know yet what specific component you would like to add:
|
||||
- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
|
||||
- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
|
||||
|
||||
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that
|
||||
we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
|
||||
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
|
||||
open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
|
||||
pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
|
||||
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
|
||||
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
|
||||
|
||||
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
|
||||
original author directly on the PR so that they can follow the progress and potentially help with questions.
|
||||
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the original author directly on the PR so that they can follow the progress and potentially help with questions.
|
||||
|
||||
If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
|
||||
|
||||
#### Copied from mechanism
|
||||
|
||||
A unique and important feature to understand when adding any pipeline, model or scheduler code is the `# Copied from` mechanism. You'll see this all over the Diffusers codebase, and the reason we use it is to keep the codebase easy to understand and maintain. Marking code with the `# Copied from` mechanism forces the marked code to be identical to the code it was copied from. This makes it easy to update and propagate changes across many files whenever you run `make fix-copies`.
|
||||
|
||||
For example, in the code example below, [`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is the original code and `AltDiffusionPipelineOutput` uses the `# Copied from` mechanism to copy it. The only difference is changing the class prefix from `Stable` to `Alt`.
|
||||
|
||||
```py
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt
|
||||
class AltDiffusionPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Alt Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
|
||||
num_channels)`.
|
||||
nsfw_content_detected (`List[bool]`)
|
||||
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
||||
`None` if safety checking could not be performed.
|
||||
"""
|
||||
```
|
||||
|
||||
To learn more, read this section of the [~Don't~ Repeat Yourself*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static) blog post.
|
||||
|
||||
## How to write a good issue
|
||||
|
||||
**The better your issue is written, the higher the chances that it will be quickly resolved.**
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Latent Consistency Distillation
|
||||
|
||||
[Latent Consistency Models (LCMs)](https://hf.co/papers/2310.04378) are able to generate high-quality images in just a few steps, representing a big leap forward because many pipelines require at least 25+ steps. LCMs are produced by applying the latent consistency distillation method to any Stable Diffusion model. This method works by applying *one-stage guided distillation* to the latent space, and incorporating a *skipping-step* method to consistently skip timesteps to accelerate the distillation process (refer to section 4.1, 4.2, and 4.3 of the paper for more details).
|
||||
|
||||
If you're training on a GPU with limited vRAM, try enabling `gradient_checkpointing`, `gradient_accumulation_steps`, and `mixed_precision` to reduce memory-usage and speedup training. You can reduce your memory-usage even more by enabling memory-efficient attention with [xFormers](../optimization/xformers) and [bitsandbytes'](https://github.com/TimDettmers/bitsandbytes) 8-bit optimizer.
|
||||
|
||||
This guide will explore the [train_lcm_distill_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
|
||||
|
||||
Before running the script, make sure you install the library from source:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install .
|
||||
```
|
||||
|
||||
Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
|
||||
|
||||
```bash
|
||||
cd examples/consistency_distillation
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
|
||||
|
||||
</Tip>
|
||||
|
||||
Initialize an 🤗 Accelerate environment (try enabling `torch.compile` to significantly speedup training):
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
To setup a default 🤗 Accelerate environment without choosing any configurations:
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
|
||||
|
||||
## Script parameters
|
||||
|
||||
<Tip>
|
||||
|
||||
The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) and let us know if you have any questions or concerns.
|
||||
|
||||
</Tip>
|
||||
|
||||
The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L419) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
|
||||
|
||||
For example, to speedup training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_lcm_distill_sd_wds.py \
|
||||
--mixed_precision="fp16"
|
||||
```
|
||||
|
||||
Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so you'll focus on the parameters that are relevant to latent consistency distillation in this guide.
|
||||
|
||||
- `--pretrained_teacher_model`: the path to a pretrained latent diffusion model to use as the teacher model
|
||||
- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify an alternative VAE (like this [VAE]((https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)) by madebyollin which works in fp16)
|
||||
- `--w_min` and `--w_max`: the minimum and maximum guidance scale values for guidance scale sampling
|
||||
- `--num_ddim_timesteps`: the number of timesteps for DDIM sampling
|
||||
- `--loss_type`: the type of loss (L2 or Huber) to calculate for latent consistency distillation; Huber loss is generally preferred because it's more robust to outliers
|
||||
- `--huber_c`: the Huber loss parameter
|
||||
|
||||
## Training script
|
||||
|
||||
The training script starts by creating a dataset class - [`Text2ImageDataset`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L141) - for preprocessing the images and creating a training dataset.
|
||||
|
||||
```py
|
||||
def transform(example):
|
||||
image = example["image"]
|
||||
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
|
||||
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
|
||||
image = TF.crop(image, c_top, c_left, resolution, resolution)
|
||||
image = TF.to_tensor(image)
|
||||
image = TF.normalize(image, [0.5], [0.5])
|
||||
|
||||
example["image"] = image
|
||||
return example
|
||||
```
|
||||
|
||||
For improved performance on reading and writing large datasets stored in the cloud, this script uses the [WebDataset](https://github.com/webdataset/webdataset) format to create a preprocessing pipeline to apply transforms and create a dataset and dataloader for training. Images are processed and fed to the training loop without having to download the full dataset first.
|
||||
|
||||
```py
|
||||
processing_pipeline = [
|
||||
wds.decode("pil", handler=wds.ignore_and_continue),
|
||||
wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
|
||||
wds.map(filter_keys({"image", "text"})),
|
||||
wds.map(transform),
|
||||
wds.to_tuple("image", "text"),
|
||||
]
|
||||
```
|
||||
|
||||
In the [`main()`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L768) function, all the necessary components like the noise scheduler, tokenizers, text encoders, and VAE are loaded. The teacher UNet is also loaded here and then you can create a student UNet from the teacher UNet. The student UNet is updated by the optimizer during training.
|
||||
|
||||
```py
|
||||
teacher_unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
```
|
||||
|
||||
Now you can create the [optimizer](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L979) to update the UNet parameters:
|
||||
|
||||
```py
|
||||
optimizer = optimizer_class(
|
||||
unet.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
```
|
||||
|
||||
Create the [dataset](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L994):
|
||||
|
||||
```py
|
||||
dataset = Text2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
global_batch_size=args.train_batch_size * accelerator.num_processes,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
resolution=args.resolution,
|
||||
shuffle_buffer_size=1000,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
)
|
||||
train_dataloader = dataset.train_dataloader
|
||||
```
|
||||
|
||||
Next, you're ready to setup the [training loop](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1049) and implement the latent consistency distillation method (see Algorithm 1 in the paper for more details). This section of the script takes care of adding noise to the latents, sampling and creating a guidance scale embedding, and predicting the original image from the noise.
|
||||
|
||||
```py
|
||||
pred_x_0 = predicted_origin(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
```
|
||||
|
||||
It gets the [teacher model predictions](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1172) and the [LCM predictions](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1209) next, calculates the loss, and then backpropagates it to the LCM.
|
||||
|
||||
```py
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
loss = torch.mean(
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
```
|
||||
|
||||
If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers tutorial](../using-diffusers/write_own_pipeline) which breaks down the basic pattern of the denoising process.
|
||||
|
||||
## Launch the script
|
||||
|
||||
Now you're ready to launch the training script and start distilling!
|
||||
|
||||
For this guide, you'll use the `--train_shards_path_or_url` to specify the path to the [Conceptual Captions 12M](https://github.com/google-research-datasets/conceptual-12m) dataset stored on the Hub [here](https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset). Set the `MODEL_DIR` environment variable to the name of the teacher model and `OUTPUT_DIR` to where you want to save the model.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="path/to/saved/model"
|
||||
|
||||
accelerate launch train_lcm_distill_sd_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=512 \
|
||||
--learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
|
||||
--validation_steps=200 \
|
||||
--checkpointing_steps=200 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=12 \
|
||||
--gradient_checkpointing --enable_xformers_memory_efficient_attention \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--use_8bit_adam \
|
||||
--resume_from_checkpoint=latest \
|
||||
--report_to=wandb \
|
||||
--seed=453645634 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Once training is complete, you can use your new LCM for inference.
|
||||
|
||||
```py
|
||||
from diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler
|
||||
import torch
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained("your-username/your-model", torch_dtype=torch.float16, variant="fp16")
|
||||
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16, variant="fp16")
|
||||
|
||||
pipeline.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
pipeline.to("cuda")
|
||||
|
||||
prompt = "sushi rolls in the form of panda heads, sushi platter"
|
||||
|
||||
image = pipeline(prompt, num_inference_steps=4, guidance_scale=1.0).images[0]
|
||||
```
|
||||
|
||||
## LoRA
|
||||
|
||||
LoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MBs). Use the [train_lcm_distill_lora_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py) or [train_lcm_distill_lora_sdxl.wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py) script to train with LoRA.
|
||||
|
||||
The LoRA training script is discussed in more detail in the [LoRA training](lora) guide.
|
||||
|
||||
## Stable Diffusion XL
|
||||
|
||||
Stable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text-encoder to its architecture. Use the [train_lcm_distill_sdxl_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py) script to train a SDXL model with LoRA.
|
||||
|
||||
The SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.
|
||||
|
||||
## Next steps
|
||||
|
||||
Congratulations on distilling a LCM model! To learn more about LCM, the following may be helpful:
|
||||
|
||||
- Learn how to use [LCMs for inference](../using-diffusers/lcm) for text-to-image, image-to-image, and with LoRA checkpoints.
|
||||
- Read the [SDXL in 4 steps with Latent Consistency LoRAs](https://huggingface.co/blog/lcm_lora) blog post to learn more about SDXL LCM-LoRA's for super fast inference, quality comparisons, benchmarks, and more.
|
||||
@@ -20,6 +20,8 @@ The Kandinsky models are a series of multilingual text-to-image generation model
|
||||
|
||||
[Kandinsky 2.2](../api/pipelines/kandinsky_v22) improves on the previous model by replacing the image encoder of the image prior model with a larger CLIP-ViT-G model to improve quality. The image prior model was also retrained on images with different resolutions and aspect ratios to generate higher-resolution images and different image sizes.
|
||||
|
||||
[Kandinsky 3](../api/pipelines/kandinsky3) simplifies the architecture and shifts away from the two-stage generation process involving the prior model and diffusion model. Instead, Kandinsky 3 uses [Flan-UL2](https://huggingface.co/google/flan-ul2) to encode text, a UNet with [BigGan-deep](https://hf.co/papers/1809.11096) blocks, and [Sber-MoVQGAN](https://github.com/ai-forever/MoVQGAN) to decode the latents into images. Text understanding and generated image quality are primarily achieved by using a larger text encoder and UNet.
|
||||
|
||||
This guide will show you how to use the Kandinsky models for text-to-image, image-to-image, inpainting, interpolation, and more.
|
||||
|
||||
Before you begin, make sure you have the following libraries installed:
|
||||
@@ -33,6 +35,10 @@ Before you begin, make sure you have the following libraries installed:
|
||||
|
||||
Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding.
|
||||
|
||||
<br>
|
||||
|
||||
Kandinsky 3 has a more concise architecture and it doesn't require a prior model. This means it's usage is identical to other diffusion models like [Stable Diffusion XL](sdxl).
|
||||
|
||||
</Tip>
|
||||
|
||||
## Text-to-image
|
||||
@@ -91,6 +97,23 @@ image
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-text-to-image.png"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Kandinsky 3">
|
||||
|
||||
Kandinsky 3 doesn't require a prior model so you can directly load the [`Kandinsky3Pipeline`] and pass a prompt to generate an image:
|
||||
|
||||
```py
|
||||
from diffusers import Kandinsky3Pipeline
|
||||
import torch
|
||||
|
||||
pipeline = Kandinsky3Pipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
|
||||
image = pipeline(prompt).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
@@ -161,6 +184,20 @@ prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kan
|
||||
pipeline = KandinskyV22Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Kandinsky 3">
|
||||
|
||||
Kandinsky 3 doesn't require a prior model so you can directly load the image-to-image pipeline:
|
||||
|
||||
```py
|
||||
from diffusers import Kandinsky3Img2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = Kandinsky3Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
@@ -218,6 +255,14 @@ make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], r
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-image-to-image.png"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Kandinsky 3">
|
||||
|
||||
```py
|
||||
image = pipeline(prompt, negative_prompt=negative_prompt, image=image, strength=0.75, num_inference_steps=25).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
@@ -307,3 +307,394 @@ prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, b
|
||||
image = pipeline(prompt=prompt).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
[IP-Adapter](https://ip-adapter.github.io/) is an effective and lightweight adapter that adds image prompting capabilities to a diffusion model. This adapter works by decoupling the cross-attention layers of the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MBs.
|
||||
|
||||
IP-Adapter works with most of our pipelines, including Stable Diffusion, Stable Diffusion XL (SDXL), ControlNet, T2I-Adapter, AnimateDiff. And you can use any custom models finetuned from the same base models. It also works with LCM-Lora out of box.
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
You can find official IP-Adapter checkpoints in [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter).
|
||||
|
||||
IP-Adapter was contributed by [okotaku](https://github.com/okotaku).
|
||||
|
||||
</Tip>
|
||||
|
||||
Let's first create a Stable Diffusion Pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
Now load the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) weights with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method.
|
||||
|
||||
```py
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
```
|
||||
|
||||
<Tip>
|
||||
IP-Adapter relies on an image encoder to generate the image features, if your IP-Adapter weights folder contains a "image_encoder" subfolder, the image encoder will be automatically loaded and registered to the pipeline. Otherwise you can so load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to a Stable Diffusion pipeline when you create it.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, CLIPVisionModelWithProjection
|
||||
import torch
|
||||
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
"h94/IP-Adapter",
|
||||
subfolder="models/image_encoder",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
</Tip>
|
||||
|
||||
IP-Adapter allows you to use both image and text to condition the image generation process. For example, let's use the bear image from the [Textual Inversion](#textual-inversion) section as the image prompt (`ip_adapter_image`) along with a text prompt to add "sunglasses". 😎
|
||||
|
||||
```py
|
||||
pipeline.set_ip_adapter_scale(0.6)
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
images = pipeline(
|
||||
prompt='best quality, high quality, wearing sunglasses',
|
||||
ip_adapter_image=image,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=50,
|
||||
generator=generator,
|
||||
).images
|
||||
images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip-bear.png" />
|
||||
</div>
|
||||
|
||||
<Tip>
|
||||
|
||||
You can use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method to adjust the text prompt and image prompt condition ratio. If you're only using the image prompt, you should set the scale to `1.0`. You can lower the scale to get more generation diversity, but it'll be less aligned with the prompt.
|
||||
`scale=0.5` can achieve good results in most cases when you use both text and image prompts.
|
||||
</Tip>
|
||||
|
||||
IP-Adapter also works great with Image-to-Image and Inpainting pipelines. See below examples of how you can use it with Image-to-Image and Inpaint.
|
||||
|
||||
<hfoptions id="tasks">
|
||||
<hfoption id="image-to-image">
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
import torch
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
|
||||
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
images = pipeline(
|
||||
prompt='best quality, high quality',
|
||||
image = image,
|
||||
ip_adapter_image=ip_image,
|
||||
num_inference_steps=50,
|
||||
generator=generator,
|
||||
strength=0.6,
|
||||
).images
|
||||
images[0]
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="inpaint">
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForInpaint
|
||||
import torch
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForInpaint.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float).to("cuda")
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
|
||||
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
|
||||
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
|
||||
|
||||
image = image.resize((512, 768))
|
||||
mask = mask.resize((512, 768))
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
images = pipeline(
|
||||
prompt='best quality, high quality',
|
||||
image = image,
|
||||
mask_image = mask,
|
||||
ip_adapter_image=ip_image,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=50,
|
||||
generator=generator,
|
||||
strength=0.5,
|
||||
).images
|
||||
images[0]
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
IP-Adapters can also be used with [SDXL](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
|
||||
|
||||
```python
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg")
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
image = pipeline(
|
||||
prompt="best quality, high quality",
|
||||
ip_adapter_image=image,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=25,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
image.save("sdxl_t2i.png")
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/sdxl_t2i.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">adapted image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
You can use the IP-Adapter face model to apply specific faces to your images. It is an effective way to maintain consistent characters in your image generations.
|
||||
Weights are loaded with the same method used for the other IP-Adapters.
|
||||
|
||||
```python
|
||||
# Load ip-adapter-full-face_sd15.bin
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
It is recommended to use `DDIMScheduler` and `EulerDiscreteScheduler` for face model.
|
||||
|
||||
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
||||
from diffusers.utils import load_image
|
||||
|
||||
noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
steps_offset=1
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
scheduler=noise_scheduler,
|
||||
).to("cuda")
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
|
||||
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
|
||||
image = pipeline(
|
||||
prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
|
||||
ip_adapter_image=image,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=50, num_images_per_prompt=1, width=512, height=704,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ipadapter_full_face_output.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### LCM-Lora
|
||||
|
||||
You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, LCMScheduler
|
||||
import torch
|
||||
from diffusers.utils import load_image
|
||||
|
||||
model_id = "sd-dreambooth-library/herge-style"
|
||||
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
||||
|
||||
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
pipe.load_lora_weights(lcm_lora_id)
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "best quality, high quality"
|
||||
image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
|
||||
images = pipe(
|
||||
prompt=prompt,
|
||||
ip_adapter_image=image,
|
||||
num_inference_steps=4,
|
||||
guidance_scale=1,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Other pipelines
|
||||
|
||||
IP-Adapter is compatible with any pipeline that (1) uses a text prompt and (2) uses Stable Diffusion or Stable Diffusion XL checkpoint. To use IP-Adapter with a different pipeline, all you need to do is to run `load_ip_adapter()` method after you create the pipeline, and then pass your image to the pipeline as `ip_adapter_image`
|
||||
|
||||
<Tip>
|
||||
|
||||
🤗 Diffusers currently only supports using IP-Adapter with some of the most popular pipelines, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require integrating IP-adapters with a pipeline that does not support it yet!
|
||||
|
||||
</Tip>
|
||||
|
||||
You can find below examples on how to use IP-Adapter with ControlNet and AnimateDiff.
|
||||
|
||||
<hfoptions id="model">
|
||||
<hfoption id="ControlNet">
|
||||
|
||||
```
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
import torch
|
||||
from diffusers.utils import load_image
|
||||
|
||||
controlnet_model_path = "lllyasviel/control_v11f1p_sd15_depth"
|
||||
controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16)
|
||||
|
||||
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png")
|
||||
depth_map = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/depth.png")
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
images = pipeline(
|
||||
prompt='best quality, high quality',
|
||||
image=depth_map,
|
||||
ip_adapter_image=image,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=50,
|
||||
generator=generator,
|
||||
).images
|
||||
images[0]
|
||||
```
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">adapted image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AnimateDiff">
|
||||
|
||||
```py
|
||||
# animate diff + ip adapter
|
||||
import torch
|
||||
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
|
||||
from diffusers.utils import export_to_gif, load_image
|
||||
|
||||
# Load the motion adapter
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
# load SD 1.5 based finetuned model
|
||||
model_id = "Lykon/DreamShaper"
|
||||
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
|
||||
|
||||
# scheduler
|
||||
scheduler = DDIMScheduler(
|
||||
clip_sample=False,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="linear",
|
||||
timestep_spacing="trailing",
|
||||
steps_offset=1
|
||||
)
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
# enable memory savings
|
||||
pipe.enable_vae_slicing()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# load ip_adapter
|
||||
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
|
||||
# load motion adapters
|
||||
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
|
||||
pipe.load_lora_weights("guoyww/animatediff-motion-lora-tilt-up", adapter_name="tilt-up")
|
||||
pipe.load_lora_weights("guoyww/animatediff-motion-lora-pan-left", adapter_name="pan-left")
|
||||
|
||||
seed = 42
|
||||
image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
|
||||
images = [image] * 3
|
||||
prompts = ["best quality, high quality"] * 3
|
||||
negative_prompt = "bad quality, worst quality"
|
||||
adapter_weights = [[0.75, 0.0, 0.0], [0.0, 0.0, 0.75], [0.0, 0.75, 0.75]]
|
||||
|
||||
# generate
|
||||
output_frames = []
|
||||
for prompt, image, adapter_weight in zip(prompts, images, adapter_weights):
|
||||
pipe.set_adapters(["zoom-out", "tilt-up", "pan-left"], adapter_weights=adapter_weight)
|
||||
output = pipe(
|
||||
prompt= prompt,
|
||||
num_frames=16,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=30,
|
||||
ip_adapter_image = image,
|
||||
generator=torch.Generator("cpu").manual_seed(seed),
|
||||
)
|
||||
frames = output.frames[0]
|
||||
output_frames.extend(frames)
|
||||
|
||||
export_to_gif(output_frames, "test_out_animation.gif")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
@@ -174,10 +174,4 @@ Set `private=True` in the [`~diffusers.utils.PushToHubMixin.push_to_hub`] functi
|
||||
controlnet.push_to_hub("my-controlnet-model-private", private=True)
|
||||
```
|
||||
|
||||
Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for.`
|
||||
|
||||
To load a model, scheduler, or pipeline from private or gated repositories, set `use_auth_token=True`:
|
||||
|
||||
```py
|
||||
model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model-private", use_auth_token=True)
|
||||
```
|
||||
Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for`. You must be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) to load a model from a private repository.
|
||||
@@ -0,0 +1,116 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Stable Diffusion XL Turbo
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
SDXL Turbo is an adversarial time-distilled [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (SDXL) model capable
|
||||
of running inference in as little as 1 step.
|
||||
|
||||
This guide will show you how to use SDXL-Turbo for text-to-image and image-to-image.
|
||||
|
||||
Before you begin, make sure you have the following libraries installed:
|
||||
|
||||
```py
|
||||
# uncomment to install the necessary libraries in Colab
|
||||
#!pip install -q diffusers transformers accelerate omegaconf
|
||||
```
|
||||
|
||||
## Load model checkpoints
|
||||
|
||||
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~StableDiffusionXLPipeline.from_pretrained`] method:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
|
||||
pipeline = pipeline.to("cuda")
|
||||
```
|
||||
|
||||
You can also use the [`~StableDiffusionXLPipeline.from_single_file`] method to load a model checkpoint stored in a single file format (`.ckpt` or `.safetensors`) from the Hub or locally:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/sdxl-turbo/blob/main/sd_xl_turbo_1.0_fp16.safetensors", torch_dtype=torch.float16)
|
||||
pipeline = pipeline.to("cuda")
|
||||
```
|
||||
|
||||
## Text-to-image
|
||||
|
||||
For text-to-image, pass a text prompt. By default, SDXL Turbo generates a 512x512 image, and that resolution gives the best results. You can try setting the `height` and `width` parameters to 768x768 or 1024x1024, but you should expect quality degradations when doing so.
|
||||
|
||||
Make sure to set `guidance_scale` to 0.0 to disable, as the model was trained without it. A single inference step is enough to generate high quality images.
|
||||
Increasing the number of steps to 2, 3 or 4 should improve image quality.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline_text2image = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
|
||||
pipeline_text2image = pipeline_text2image.to("cuda")
|
||||
|
||||
prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
|
||||
|
||||
image = pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sdxl-turbo-text2img.png" alt="generated image of a racoon in a robe"/>
|
||||
</div>
|
||||
|
||||
## Image-to-image
|
||||
|
||||
For image-to-image generation, make sure that `num_inference_steps * strength` is larger or equal to 1.
|
||||
The image-to-image pipeline will run for `int(num_inference_steps * strength)` steps, e.g. `0.5 * 2.0 = 1` step in
|
||||
our example below.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
# use from_pipe to avoid consuming additional memory when loading a checkpoint
|
||||
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to("cuda")
|
||||
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
|
||||
init_image = init_image.resize((512, 512))
|
||||
|
||||
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
|
||||
|
||||
image = pipeline(prompt, image=init_image, strength=0.5, guidance_scale=0.0, num_inference_steps=2).images[0]
|
||||
make_image_grid([init_image, image], rows=1, cols=2)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sdxl-turbo-img2img.png" alt="Image-to-image generation sample using SDXL Turbo"/>
|
||||
</div>
|
||||
|
||||
## Speed-up SDXL Turbo even more
|
||||
|
||||
- Compile the UNet if you are using PyTorch version 2 or better. The first inference run will be very slow, but subsequent ones will be much faster.
|
||||
|
||||
```py
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
- When using the default VAE, keep it in `float32` to avoid costly `dtype` conversions before and after each generation. You only need to do this one before your first generation:
|
||||
|
||||
```py
|
||||
pipe.upcast_vae()
|
||||
```
|
||||
|
||||
As an alternative, you can also use a [16-bit VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) created by community member [`@madebyollin`](https://huggingface.co/madebyollin) that does not need to be upcasted to `float32`.
|
||||
@@ -0,0 +1,134 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Stable Video Diffusion
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
[Stable Video Diffusion](https://static1.squarespace.com/static/6213c340453c3f502425776e/t/655ce779b9d47d342a93c890/1700587395994/stable_video_diffusion.pdf) is a powerful image-to-video generation model that can generate high resolution (576x1024) 2-4 second videos conditioned on the input image.
|
||||
|
||||
This guide will show you how to use SVD to short generate videos from images.
|
||||
|
||||
Before you begin, make sure you have the following libraries installed:
|
||||
|
||||
```py
|
||||
!pip install -q -U diffusers transformers accelerate
|
||||
```
|
||||
|
||||
## Image to Video Generation
|
||||
|
||||
The are two variants of SVD. [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid)
|
||||
and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). The svd checkpoint is trained to generate 14 frames and the svd-xt checkpoint is further
|
||||
finetuned to generate 25 frames.
|
||||
|
||||
We will use the `svd-xt` checkpoint for this guide.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import StableVideoDiffusionPipeline
|
||||
from diffusers.utils import load_image, export_to_video
|
||||
|
||||
pipe = StableVideoDiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Load the conditioning image
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
|
||||
image = image.resize((1024, 576))
|
||||
|
||||
generator = torch.manual_seed(42)
|
||||
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
|
||||
|
||||
export_to_video(frames, "generated.mp4", fps=7)
|
||||
```
|
||||
|
||||
<video controls width="1024" height="576">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.webm" type="video/webm" />
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4" type="video/mp4" />
|
||||
</video>
|
||||
|
||||
<Tip>
|
||||
Since generating videos is more memory intensive we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory.
|
||||
Setting `decode_chunk_size=1` will decode one frame at a time and will use the least amount of memory but the video might have some flickering.
|
||||
|
||||
Additionally, we also use [model cpu offloading](../../optimization/memory#model-offloading) to reduce the memory usage.
|
||||
</Tip>
|
||||
|
||||
|
||||
### Torch.compile
|
||||
|
||||
You can achieve a 20-25% speed-up at the expense of slightly increased memory by compiling the UNet as follows:
|
||||
|
||||
```diff
|
||||
- pipe.enable_model_cpu_offload()
|
||||
+ pipe.to("cuda")
|
||||
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
### Low-memory
|
||||
|
||||
Video generation is very memory intensive as we have to essentially generate `num_frames` all at once. The mechanism is very comparable to text-to-image generation with a high batch size. To reduce the memory requirement you have multiple options. The following options trade inference speed against lower memory requirement:
|
||||
- enable model offloading: Each component of the pipeline is offloaded to CPU once it's not needed anymore.
|
||||
- enable feed-forward chunking: The feed-forward layer runs in a loop instead of running with a single huge feed-forward batch size
|
||||
- reduce `decode_chunk_size`: This means that the VAE decodes frames in chunks instead of decoding them all together. **Note**: In addition to leading to a small slowdown, this method also slightly leads to video quality deterioration
|
||||
|
||||
You can enable them as follows:
|
||||
|
||||
```diff
|
||||
-pipe.enable_model_cpu_offload()
|
||||
-frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
|
||||
+pipe.enable_model_cpu_offload()
|
||||
+pipe.unet.enable_forward_chunking()
|
||||
+frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
|
||||
```
|
||||
|
||||
|
||||
Including all these tricks should lower the memory requirement to less than 8GB VRAM.
|
||||
|
||||
### Micro-conditioning
|
||||
|
||||
Along with conditioning image Stable Diffusion Video also allows providing micro-conditioning that allows more control over the generated video.
|
||||
It accepts the following arguments:
|
||||
|
||||
- `fps`: The frames per second of the generated video.
|
||||
- `motion_bucket_id`: The motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id will increase the motion of the generated video.
|
||||
- `noise_aug_strength`: The amount of noise added to the conditioning image. The higher the values the less the video will resemble the conditioning image. Increasing this value will also increase the motion of the generated video.
|
||||
|
||||
Here is an example of using micro-conditioning to generate a video with more motion.
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import StableVideoDiffusionPipeline
|
||||
from diffusers.utils import load_image, export_to_video
|
||||
|
||||
pipe = StableVideoDiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Load the conditioning image
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
|
||||
image = image.resize((1024, 576))
|
||||
|
||||
generator = torch.manual_seed(42)
|
||||
frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=180, noise_aug_strength=0.1).frames[0]
|
||||
export_to_video(frames, "generated.mp4", fps=7)
|
||||
```
|
||||
|
||||
<video width="1024" height="576" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated_motion.mp4" type="video/mp4">
|
||||
</video>
|
||||
|
||||
@@ -14,54 +14,41 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
Unconditional image generation is a relatively straightforward task. The model only generates images - without any additional context like text or an image - resembling the training data it was trained on.
|
||||
Unconditional image generation generates images that look like a random sample from the training data the model was trained on because the denoising process is not guided by any additional context like text or image.
|
||||
|
||||
The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference.
|
||||
To get started, use the [`DiffusionPipeline`] to load the [anton-l/ddpm-butterflies-128](https://huggingface.co/anton-l/ddpm-butterflies-128) checkpoint to generate images of butterflies. The [`DiffusionPipeline`] downloads and caches all the model components required to generate an image.
|
||||
|
||||
Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
|
||||
You can use any of the 🧨 Diffusers [checkpoints](https://huggingface.co/models?library=diffusers&sort=downloads) from the Hub (the checkpoint you'll use generates images of butterflies).
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 Want to train your own unconditional image generation model? Take a look at the training [guide](../training/unconditional_training) to learn how to generate your own images.
|
||||
|
||||
</Tip>
|
||||
|
||||
In this guide, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239):
|
||||
|
||||
```python
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
|
||||
```
|
||||
|
||||
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
|
||||
Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on a GPU.
|
||||
You can move the generator object to a GPU, just like you would in PyTorch:
|
||||
|
||||
```python
|
||||
generator.to("cuda")
|
||||
```
|
||||
|
||||
Now you can use the `generator` to generate an image:
|
||||
|
||||
```python
|
||||
generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128").to("cuda")
|
||||
image = generator().images[0]
|
||||
image
|
||||
```
|
||||
|
||||
The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
|
||||
<Tip>
|
||||
|
||||
You can save the image by calling:
|
||||
Want to generate images of something else? Take a look at the training [guide](../training/unconditional_training) to learn how to train a model to generate your own images.
|
||||
|
||||
```python
|
||||
</Tip>
|
||||
|
||||
The output image is a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object that can be saved:
|
||||
|
||||
```py
|
||||
image.save("generated_image.png")
|
||||
```
|
||||
|
||||
Try out the Spaces below, and feel free to play around with the inference steps parameter to see how it affects the image quality!
|
||||
You can also try experimenting with the `num_inference_steps` parameter, which controls the number of denoising steps. More denoising steps typically produce higher quality images, but it'll take longer to generate. Feel free to play around with this parameter to see how it affects the image quality.
|
||||
|
||||
```py
|
||||
image = generator(num_inference_steps=100).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
Try out the Space below to generate an image of a butterfly!
|
||||
|
||||
<iframe
|
||||
src="https://stevhliu-ddpm-butterflies-128.hf.space"
|
||||
src="https://stevhliu-unconditional-image-generation.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="500"
|
||||
|
||||
@@ -96,3 +96,4 @@ specific language governing permissions and limitations under the License.
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
|
||||
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
| [stable_diffusion_ldm3d](./api/pipelines/stable_diffusion/ldm3d_diffusion) | [LDM3D: Latent Diffusion Model for 3D](https://arxiv.org/abs/2305.10853) | Text to Image and Depth Generation |
|
||||
| [stable_diffusion_upscaler_ldm3d](./api/pipelines/stable_diffusion/ldm3d_diffusion) | [LDM3D-VR: Latent Diffusion Model for 3D VR](https://arxiv.org/pdf/2311.03226) | Image and Depth Upscaling |
|
||||
|
||||
+4
-6
@@ -18,8 +18,7 @@ limitations under the License.
|
||||
Diffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library
|
||||
for a variety of use cases involving training or fine-tuning.
|
||||
|
||||
**Note**: If you are looking for **official** examples on how to use `diffusers` for inference,
|
||||
please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
|
||||
**Note**: If you are looking for **official** examples on how to use `diffusers` for inference, please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
|
||||
|
||||
Our examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
|
||||
More specifically, this means:
|
||||
@@ -27,11 +26,10 @@ More specifically, this means:
|
||||
- **Self-contained**: An example script shall only depend on "pip-install-able" Python packages that can be found in a `requirements.txt` file. Example scripts shall **not** depend on any local files. This means that one can simply download an example script, *e.g.* [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), install the required dependencies, *e.g.* [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt) and execute the example script.
|
||||
- **Easy-to-tweak**: While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out-of-the box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data and the training loop to allow you to tweak and edit them as required.
|
||||
- **Beginner-friendly**: We do not aim for providing state-of-the-art training scripts for the newest models, but rather examples that can be used as a way to better understand diffusion models and how to use them with the `diffusers` library. We often purposefully leave out certain state-of-the-art methods if we consider them too complex for beginners.
|
||||
- **One-purpose-only**: Examples should show one task and one task only. Even if a task is from a modeling
|
||||
point of view very similar, *e.g.* image super-resolution and image modification tend to use the same model and training method, we want examples to showcase only one task to keep them as readable and easy-to-understand as possible.
|
||||
- **One-purpose-only**: Examples should show one task and one task only. Even if a task is from a modeling point of view very similar, *e.g.* image super-resolution and image modification tend to use the same model and training method, we want examples to showcase only one task to keep them as readable and easy-to-understand as possible.
|
||||
|
||||
We provide **official** examples that cover the most popular tasks of diffusion models.
|
||||
*Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above.
|
||||
*Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above.
|
||||
If you feel like another important example should exist, we are more than happy to welcome a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) or directly a [Pull Request](https://github.com/huggingface/diffusers/compare) from you!
|
||||
|
||||
Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support:
|
||||
@@ -39,7 +37,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
|
||||
| Task | 🤗 Accelerate | 🤗 Datasets | Colab
|
||||
|---|---|:---:|:---:|
|
||||
| [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
|
||||
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
|
||||
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
|
||||
| [**ControlNet**](./controlnet) | ✅ | ✅ | -
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+753
-70
File diff suppressed because it is too large
Load Diff
@@ -5,10 +5,11 @@ from typing import Dict, List, Union
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from diffusers import DiffusionPipeline, __version__
|
||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from diffusers.utils import CONFIG_NAME, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
|
||||
|
||||
class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
@@ -57,6 +58,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
return (temp_dict, meta_keys)
|
||||
|
||||
@torch.no_grad()
|
||||
@validate_hf_hub_args
|
||||
def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
|
||||
"""
|
||||
Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints(weights) of the models passed
|
||||
@@ -69,7 +71,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
**kwargs:
|
||||
Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
|
||||
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map.
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.
|
||||
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
@@ -81,12 +83,12 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
|
||||
"""
|
||||
# Default kwargs from DiffusionPipeline
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
@@ -123,7 +125,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
token=token,
|
||||
revision=revision,
|
||||
)
|
||||
config_dicts.append(config_dict)
|
||||
@@ -159,7 +161,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
token=token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
user_agent=user_agent,
|
||||
|
||||
Executable
+466
@@ -0,0 +1,466 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from math import pi
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
class DPSPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for Diffusion Posterior Sampling.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Parameters:
|
||||
unet ([`UNet2DModel`]):
|
||||
A `UNet2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
|
||||
[`DDPMScheduler`], or [`DDIMScheduler`].
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
measurement: torch.Tensor,
|
||||
operator: torch.nn.Module,
|
||||
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
batch_size: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
num_inference_steps: int = 1000,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
zeta: float = 0.3,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
measurement (`torch.Tensor`, *required*):
|
||||
A 'torch.Tensor', the corrupted image
|
||||
operator (`torch.nn.Module`, *required*):
|
||||
A 'torch.nn.Module', the operator generating the corrupted image
|
||||
loss_fn (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *required*):
|
||||
A 'Callable[[torch.Tensor, torch.Tensor], torch.Tensor]', the loss function used
|
||||
between the measurements, for most of the cases using RMSE is fine.
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
num_inference_steps (`int`, *optional*, defaults to 1000):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> from diffusers import DDPMPipeline
|
||||
|
||||
>>> # load model and scheduler
|
||||
>>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
|
||||
|
||||
>>> # run pipeline in inference (sample random noise and denoise)
|
||||
>>> image = pipe().images[0]
|
||||
|
||||
>>> # save image
|
||||
>>> image.save("ddpm_generated_image.png")
|
||||
```
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images
|
||||
"""
|
||||
# Sample gaussian noise to begin loop
|
||||
if isinstance(self.unet.config.sample_size, int):
|
||||
image_shape = (
|
||||
batch_size,
|
||||
self.unet.config.in_channels,
|
||||
self.unet.config.sample_size,
|
||||
self.unet.config.sample_size,
|
||||
)
|
||||
else:
|
||||
image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
|
||||
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = randn_tensor(image_shape, generator=generator)
|
||||
image = image.to(self.device)
|
||||
else:
|
||||
image = randn_tensor(image_shape, generator=generator, device=self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
with torch.enable_grad():
|
||||
# 1. predict noise model_output
|
||||
image = image.requires_grad_()
|
||||
model_output = self.unet(image, t).sample
|
||||
|
||||
# 2. compute previous image x'_{t-1} and original prediction x0_{t}
|
||||
scheduler_out = self.scheduler.step(model_output, t, image, generator=generator)
|
||||
image_pred, origi_pred = scheduler_out.prev_sample, scheduler_out.pred_original_sample
|
||||
|
||||
# 3. compute y'_t = f(x0_{t})
|
||||
measurement_pred = operator(origi_pred)
|
||||
|
||||
# 4. compute loss = d(y, y'_t-1)
|
||||
loss = loss_fn(measurement, measurement_pred)
|
||||
loss.backward()
|
||||
|
||||
print("distance: {0:.4f}".format(loss.item()))
|
||||
|
||||
with torch.no_grad():
|
||||
image_pred = image_pred - zeta * image.grad
|
||||
image = image_pred.detach()
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import scipy
|
||||
from torch import nn
|
||||
from torchvision.utils import save_image
|
||||
|
||||
# defining the operators f(.) of y = f(x)
|
||||
# super-resolution operator
|
||||
class SuperResolutionOperator(nn.Module):
|
||||
def __init__(self, in_shape, scale_factor):
|
||||
super().__init__()
|
||||
|
||||
# Resizer local class, do not use outiside the SR operator class
|
||||
class Resizer(nn.Module):
|
||||
def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):
|
||||
super(Resizer, self).__init__()
|
||||
|
||||
# First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa
|
||||
scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)
|
||||
|
||||
# Choose interpolation method, each method has the matching kernel size
|
||||
def cubic(x):
|
||||
absx = np.abs(x)
|
||||
absx2 = absx**2
|
||||
absx3 = absx**3
|
||||
return (1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) + (
|
||||
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
|
||||
) * ((1 < absx) & (absx <= 2))
|
||||
|
||||
def lanczos2(x):
|
||||
return (
|
||||
(np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps)
|
||||
/ ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)
|
||||
) * (abs(x) < 2)
|
||||
|
||||
def box(x):
|
||||
return ((-0.5 <= x) & (x < 0.5)) * 1.0
|
||||
|
||||
def lanczos3(x):
|
||||
return (
|
||||
(np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps)
|
||||
/ ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)
|
||||
) * (abs(x) < 3)
|
||||
|
||||
def linear(x):
|
||||
return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
|
||||
|
||||
method, kernel_width = {
|
||||
"cubic": (cubic, 4.0),
|
||||
"lanczos2": (lanczos2, 4.0),
|
||||
"lanczos3": (lanczos3, 6.0),
|
||||
"box": (box, 1.0),
|
||||
"linear": (linear, 2.0),
|
||||
None: (cubic, 4.0), # set default interpolation method as cubic
|
||||
}.get(kernel)
|
||||
|
||||
# Antialiasing is only used when downscaling
|
||||
antialiasing *= np.any(np.array(scale_factor) < 1)
|
||||
|
||||
# Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient
|
||||
sorted_dims = np.argsort(np.array(scale_factor))
|
||||
self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]
|
||||
|
||||
# Iterate over dimensions to calculate local weights for resizing and resize each time in one direction
|
||||
field_of_view_list = []
|
||||
weights_list = []
|
||||
for dim in self.sorted_dims:
|
||||
# for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the
|
||||
# weights that multiply the values there to get its result.
|
||||
weights, field_of_view = self.contributions(
|
||||
in_shape[dim], output_shape[dim], scale_factor[dim], method, kernel_width, antialiasing
|
||||
)
|
||||
|
||||
# convert to torch tensor
|
||||
weights = torch.tensor(weights.T, dtype=torch.float32)
|
||||
|
||||
# We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
|
||||
# tmp_im[field_of_view.T], (bsxfun style)
|
||||
weights_list.append(
|
||||
nn.Parameter(
|
||||
torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),
|
||||
requires_grad=False,
|
||||
)
|
||||
)
|
||||
field_of_view_list.append(
|
||||
nn.Parameter(
|
||||
torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False
|
||||
)
|
||||
)
|
||||
|
||||
self.field_of_view = nn.ParameterList(field_of_view_list)
|
||||
self.weights = nn.ParameterList(weights_list)
|
||||
|
||||
def forward(self, in_tensor):
|
||||
x = in_tensor
|
||||
|
||||
# Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim
|
||||
for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):
|
||||
# To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
|
||||
x = torch.transpose(x, dim, 0)
|
||||
|
||||
# This is a bit of a complicated multiplication: x[field_of_view.T] is a tensor of order image_dims+1.
|
||||
# for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim
|
||||
# only, this is why it only adds 1 dim to 5the shape). We then multiply, for each pixel, its set of positions with
|
||||
# the matching set of weights. we do this by this big tensor element-wise multiplication (MATLAB bsxfun style:
|
||||
# matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the
|
||||
# same number
|
||||
x = torch.sum(x[fov] * w, dim=0)
|
||||
|
||||
# Finally we swap back the axes to the original order
|
||||
x = torch.transpose(x, dim, 0)
|
||||
|
||||
return x
|
||||
|
||||
def fix_scale_and_size(self, input_shape, output_shape, scale_factor):
|
||||
# First fixing the scale-factor (if given) to be standardized the function expects (a list of scale factors in the
|
||||
# same size as the number of input dimensions)
|
||||
if scale_factor is not None:
|
||||
# By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.
|
||||
if np.isscalar(scale_factor) and len(input_shape) > 1:
|
||||
scale_factor = [scale_factor, scale_factor]
|
||||
|
||||
# We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales
|
||||
scale_factor = list(scale_factor)
|
||||
scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor
|
||||
|
||||
# Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size
|
||||
# to all the unspecified dimensions
|
||||
if output_shape is not None:
|
||||
output_shape = list(input_shape[len(output_shape) :]) + list(np.uint(np.array(output_shape)))
|
||||
|
||||
# Dealing with the case of non-give scale-factor, calculating according to output-shape. note that this is
|
||||
# sub-optimal, because there can be different scales to the same output-shape.
|
||||
if scale_factor is None:
|
||||
scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)
|
||||
|
||||
# Dealing with missing output-shape. calculating according to scale-factor
|
||||
if output_shape is None:
|
||||
output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))
|
||||
|
||||
return scale_factor, output_shape
|
||||
|
||||
def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
||||
# This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
|
||||
# such that each position from the field_of_view will be multiplied with a matching filter from the
|
||||
# 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers
|
||||
# around it. This is only done for one dimension of the image.
|
||||
|
||||
# When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of
|
||||
# 1/sf. this means filtering is more 'low-pass filter'.
|
||||
fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
|
||||
kernel_width *= 1.0 / scale if antialiasing else 1.0
|
||||
|
||||
# These are the coordinates of the output image
|
||||
out_coordinates = np.arange(1, out_length + 1)
|
||||
|
||||
# since both scale-factor and output size can be provided simulatneously, perserving the center of the image requires shifting
|
||||
# the output coordinates. the deviation is because out_length doesn't necesary equal in_length*scale.
|
||||
# to keep the center we need to subtract half of this deivation so that we get equal margins for boths sides and center is preserved.
|
||||
shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2
|
||||
|
||||
# These are the matching positions of the output-coordinates on the input image coordinates.
|
||||
# Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
|
||||
# [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
|
||||
# The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to
|
||||
# the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big
|
||||
# one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor).
|
||||
# So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is
|
||||
# at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means:
|
||||
# (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf)
|
||||
match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)
|
||||
|
||||
# This is the left boundary to start multiplying the filter from, it depends on the size of the filter
|
||||
left_boundary = np.floor(match_coordinates - kernel_width / 2)
|
||||
|
||||
# Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers
|
||||
# of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them)
|
||||
expanded_kernel_width = np.ceil(kernel_width) + 2
|
||||
|
||||
# Determine a set of field_of_view for each each output position, these are the pixels in the input image
|
||||
# that the pixel in the output image 'sees'. We get a matrix whos horizontal dim is the output pixels (big) and the
|
||||
# vertical dim is the pixels it 'sees' (kernel_size + 2)
|
||||
field_of_view = np.squeeze(
|
||||
np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)
|
||||
)
|
||||
|
||||
# Assign weight to each pixel in the field of view. A matrix whos horizontal dim is the output pixels and the
|
||||
# vertical dim is a list of weights matching to the pixel in the field of view (that are specified in
|
||||
# 'field_of_view')
|
||||
weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
|
||||
|
||||
# Normalize weights to sum up to 1. be careful from dividing by 0
|
||||
sum_weights = np.sum(weights, axis=1)
|
||||
sum_weights[sum_weights == 0] = 1.0
|
||||
weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
|
||||
|
||||
# We use this mirror structure as a trick for reflection padding at the boundaries
|
||||
mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
|
||||
field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
|
||||
|
||||
# Get rid of weights and pixel positions that are of zero weight
|
||||
non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
|
||||
weights = np.squeeze(weights[:, non_zero_out_pixels])
|
||||
field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
|
||||
|
||||
# Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size
|
||||
return weights, field_of_view
|
||||
|
||||
self.down_sample = Resizer(in_shape, 1 / scale_factor)
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, data, **kwargs):
|
||||
return self.down_sample(data)
|
||||
|
||||
# Gaussian blurring operator
|
||||
class GaussialBlurOperator(nn.Module):
|
||||
def __init__(self, kernel_size, intensity):
|
||||
super().__init__()
|
||||
|
||||
class Blurkernel(nn.Module):
|
||||
def __init__(self, blur_type="gaussian", kernel_size=31, std=3.0):
|
||||
super().__init__()
|
||||
self.blur_type = blur_type
|
||||
self.kernel_size = kernel_size
|
||||
self.std = std
|
||||
self.seq = nn.Sequential(
|
||||
nn.ReflectionPad2d(self.kernel_size // 2),
|
||||
nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
|
||||
)
|
||||
self.weights_init()
|
||||
|
||||
def forward(self, x):
|
||||
return self.seq(x)
|
||||
|
||||
def weights_init(self):
|
||||
if self.blur_type == "gaussian":
|
||||
n = np.zeros((self.kernel_size, self.kernel_size))
|
||||
n[self.kernel_size // 2, self.kernel_size // 2] = 1
|
||||
k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
|
||||
k = torch.from_numpy(k)
|
||||
self.k = k
|
||||
for name, f in self.named_parameters():
|
||||
f.data.copy_(k)
|
||||
|
||||
def update_weights(self, k):
|
||||
if not torch.is_tensor(k):
|
||||
k = torch.from_numpy(k)
|
||||
for name, f in self.named_parameters():
|
||||
f.data.copy_(k)
|
||||
|
||||
def get_kernel(self):
|
||||
return self.k
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.conv = Blurkernel(blur_type="gaussian", kernel_size=kernel_size, std=intensity)
|
||||
self.kernel = self.conv.get_kernel()
|
||||
self.conv.update_weights(self.kernel.type(torch.float32))
|
||||
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, data, **kwargs):
|
||||
return self.conv(data)
|
||||
|
||||
def transpose(self, data, **kwargs):
|
||||
return data
|
||||
|
||||
def get_kernel(self):
|
||||
return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
|
||||
|
||||
# assuming the forward process y = f(x) is polluted by Gaussian noise, use l2 norm
|
||||
def RMSELoss(yhat, y):
|
||||
return torch.sqrt(torch.sum((yhat - y) ** 2))
|
||||
|
||||
# set up source image
|
||||
src = Image.open("sample.png")
|
||||
# read image into [1,3,H,W]
|
||||
src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2, 0, 1)[None]
|
||||
# normalize image to [-1,1]
|
||||
src = (src / 127.5) - 1.0
|
||||
src = src.to("cuda")
|
||||
|
||||
# set up operator and measurement
|
||||
# operator = SuperResolutionOperator(in_shape=src.shape, scale_factor=4).to("cuda")
|
||||
operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
|
||||
measurement = operator(src)
|
||||
|
||||
# set up scheduler
|
||||
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
|
||||
scheduler.set_timesteps(1000)
|
||||
|
||||
# set up model
|
||||
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
|
||||
|
||||
save_image((src + 1.0) / 2.0, "dps_src.png")
|
||||
save_image((measurement + 1.0) / 2.0, "dps_mea.png")
|
||||
|
||||
# finally, the pipeline
|
||||
dpspipe = DPSPipeline(model, scheduler)
|
||||
image = dpspipe(
|
||||
measurement=measurement,
|
||||
operator=operator,
|
||||
loss_fn=RMSELoss,
|
||||
zeta=1.0,
|
||||
).images[0]
|
||||
|
||||
image.save("dps_generated_image.png")
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import ast
|
||||
import gc
|
||||
import inspect
|
||||
import math
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
@@ -23,16 +24,29 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.attention import Attention, GatedSelfAttentionDense
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import logging, replace_example_docstring
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
@@ -44,6 +58,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> pipe = DiffusionPipeline.from_pretrained(
|
||||
... "longlian/lmd_plus",
|
||||
... custom_pipeline="llm_grounded_diffusion",
|
||||
... custom_revision="main",
|
||||
... variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
@@ -96,7 +111,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
|
||||
# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
|
||||
DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
|
||||
DEFAULT_GUIDANCE_ATTN_KEYS = [
|
||||
("mid", 0, 0, 0),
|
||||
("up", 1, 0, 0),
|
||||
("up", 1, 1, 0),
|
||||
("up", 1, 2, 0),
|
||||
]
|
||||
|
||||
|
||||
def convert_attn_keys(key):
|
||||
@@ -126,7 +146,15 @@ def scale_proportion(obj_box, H, W):
|
||||
|
||||
# Adapted from the parent class `AttnProcessor2_0`
|
||||
class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
def __init__(self, attn_processor_key, hidden_size, cross_attention_dim, hook=None, fast_attn=True, enabled=True):
|
||||
def __init__(
|
||||
self,
|
||||
attn_processor_key,
|
||||
hidden_size,
|
||||
cross_attention_dim,
|
||||
hook=None,
|
||||
fast_attn=True,
|
||||
enabled=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn_processor_key = attn_processor_key
|
||||
self.hidden_size = hidden_size
|
||||
@@ -165,15 +193,16 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, scale=scale)
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, scale=scale)
|
||||
value = attn.to_v(encoder_hidden_states, scale=scale)
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
@@ -186,7 +215,13 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
|
||||
if self.hook is not None and self.enabled:
|
||||
# Call the hook with query, key, value, and attention maps
|
||||
self.hook(self.attn_processor_key, query_batch_dim, key_batch_dim, value_batch_dim, attention_probs)
|
||||
self.hook(
|
||||
self.attn_processor_key,
|
||||
query_batch_dim,
|
||||
key_batch_dim,
|
||||
value_batch_dim,
|
||||
attention_probs,
|
||||
)
|
||||
|
||||
if self.fast_attn:
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
@@ -202,7 +237,12 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -211,7 +251,7 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -226,7 +266,9 @@ class AttnProcessorWithHook(AttnProcessor2_0):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
class LLMGroundedDiffusionPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf.
|
||||
|
||||
@@ -257,6 +299,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
Whether a safety checker is needed for this pipeline.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
objects_text = "Objects: "
|
||||
bg_prompt_text = "Background prompt: "
|
||||
bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
|
||||
@@ -272,12 +319,91 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
)
|
||||
# This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Initialize the attention hooks for LLM-grounded Diffusion
|
||||
self.register_attn_hooks(unet)
|
||||
self._saved_attn = None
|
||||
|
||||
@@ -464,7 +590,14 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
return token_map
|
||||
|
||||
def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_found=False, verbose=False):
|
||||
def get_phrase_indices(
|
||||
self,
|
||||
prompt,
|
||||
phrases,
|
||||
token_map=None,
|
||||
add_suffix_if_not_found=False,
|
||||
verbose=False,
|
||||
):
|
||||
for obj in phrases:
|
||||
# Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
|
||||
if obj not in prompt:
|
||||
@@ -485,7 +618,14 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
phrase_token_map_str = " ".join(phrase_token_map)
|
||||
|
||||
if verbose:
|
||||
logger.info("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases)
|
||||
logger.info(
|
||||
"Full str:",
|
||||
token_map_str,
|
||||
"Substr:",
|
||||
phrase_token_map_str,
|
||||
"Phrase:",
|
||||
phrases,
|
||||
)
|
||||
|
||||
# Count the number of token before substr
|
||||
# The substring comes with a trailing space that needs to be removed by minus one in the index.
|
||||
@@ -552,7 +692,15 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
return loss
|
||||
|
||||
def compute_ca_loss(self, saved_attn, bboxes, phrase_indices, guidance_attn_keys, verbose=False, **kwargs):
|
||||
def compute_ca_loss(
|
||||
self,
|
||||
saved_attn,
|
||||
bboxes,
|
||||
phrase_indices,
|
||||
guidance_attn_keys,
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss.
|
||||
`AttnProcessor` will put attention maps into the `save_attn_to_dict`.
|
||||
@@ -605,6 +753,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -662,6 +811,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -724,9 +874,10 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
phrase_indices = []
|
||||
prompt_parsed = []
|
||||
for prompt_item in prompt:
|
||||
phrase_indices_parsed_item, prompt_parsed_item = self.get_phrase_indices(
|
||||
prompt_item, add_suffix_if_not_found=True
|
||||
)
|
||||
(
|
||||
phrase_indices_parsed_item,
|
||||
prompt_parsed_item,
|
||||
) = self.get_phrase_indices(prompt_item, add_suffix_if_not_found=True)
|
||||
phrase_indices.append(phrase_indices_parsed_item)
|
||||
prompt_parsed.append(prompt_parsed_item)
|
||||
prompt = prompt_parsed
|
||||
@@ -759,6 +910,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
@@ -801,7 +957,10 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
if n_objs:
|
||||
cond_boxes[:n_objs] = torch.tensor(boxes)
|
||||
text_embeddings = torch.zeros(
|
||||
max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
|
||||
max_objs,
|
||||
self.unet.config.cross_attention_dim,
|
||||
device=device,
|
||||
dtype=self.text_encoder.dtype,
|
||||
)
|
||||
if n_objs:
|
||||
text_embeddings[:n_objs] = _text_embeddings
|
||||
@@ -833,6 +992,9 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
|
||||
loss_attn = torch.tensor(10000.0)
|
||||
|
||||
# 7. Denoising loop
|
||||
@@ -869,6 +1031,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
@@ -1013,3 +1176,438 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
|
||||
self.enable_attn_hook(enabled=False)
|
||||
|
||||
return latents, loss
|
||||
|
||||
# Below are methods copied from StableDiffusionPipeline
|
||||
# The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
||||
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
|
||||
Args:
|
||||
timesteps (`torch.Tensor`):
|
||||
generate embedding vectors at these timesteps
|
||||
embedding_dim (`int`, *optional*, defaults to 512):
|
||||
dimension of the embeddings to generate
|
||||
dtype:
|
||||
data type of the generated embeddings
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
||||
"""
|
||||
assert len(w.shape) == 1
|
||||
w = w * 1000.0
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
||||
emb = w.to(dtype)[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1))
|
||||
assert emb.shape == (w.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@@ -250,6 +250,7 @@ def get_weighted_text_embeddings_sdxl(
|
||||
neg_prompt: str = "",
|
||||
neg_prompt_2: str = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
"""
|
||||
This function can process long prompt with weights, no length limitation
|
||||
@@ -262,10 +263,13 @@ def get_weighted_text_embeddings_sdxl(
|
||||
neg_prompt (str)
|
||||
neg_prompt_2 (str)
|
||||
num_images_per_prompt (int)
|
||||
device (torch.device)
|
||||
Returns:
|
||||
prompt_embeds (torch.Tensor)
|
||||
neg_prompt_embeds (torch.Tensor)
|
||||
"""
|
||||
device = device or pipe._execution_device
|
||||
|
||||
if prompt_2:
|
||||
prompt = f"{prompt} {prompt_2}"
|
||||
|
||||
@@ -330,17 +334,17 @@ def get_weighted_text_embeddings_sdxl(
|
||||
# get prompt embeddings one by one is not working.
|
||||
for i in range(len(prompt_token_groups)):
|
||||
# get positive prompt embeddings with weights
|
||||
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
|
||||
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
|
||||
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
|
||||
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
|
||||
|
||||
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
|
||||
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
|
||||
|
||||
# use first text encoder
|
||||
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
|
||||
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
|
||||
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
|
||||
|
||||
# use second text encoder
|
||||
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
|
||||
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
|
||||
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
|
||||
pooled_prompt_embeds = prompt_embeds_2[0]
|
||||
|
||||
@@ -357,16 +361,16 @@ def get_weighted_text_embeddings_sdxl(
|
||||
embeds.append(token_embedding)
|
||||
|
||||
# get negative prompt embeddings with weights
|
||||
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
|
||||
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
|
||||
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
|
||||
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
|
||||
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
|
||||
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
|
||||
|
||||
# use first text encoder
|
||||
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
|
||||
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
|
||||
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
|
||||
|
||||
# use second text encoder
|
||||
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
|
||||
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
|
||||
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
|
||||
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,772 @@
|
||||
# Copyright 2023 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.image_processor import PipelineDepthInput, PipelineImageInput, VaeImageProcessorLDM3D
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d import LDM3DPipelineOutput
|
||||
from diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> from diffusers import StableDiffusionUpscaleLDM3DPipeline
|
||||
>>> from PIL import Image
|
||||
>>> from io import BytesIO
|
||||
>>> import requests
|
||||
|
||||
>>> pipe = StableDiffusionUpscaleLDM3DPipeline.from_pretrained("Intel/ldm3d-sr")
|
||||
>>> pipe = pipe.to("cuda")
|
||||
>>> rgb_path = "https://huggingface.co/Intel/ldm3d-sr/resolve/main/lemons_ldm3d_rgb.jpg"
|
||||
>>> depth_path = "https://huggingface.co/Intel/ldm3d-sr/resolve/main/lemons_ldm3d_depth.png"
|
||||
>>> low_res_rgb = Image.open(BytesIO(requests.get(rgb_path).content)).convert("RGB")
|
||||
>>> low_res_depth = Image.open(BytesIO(requests.get(depth_path).content)).convert("L")
|
||||
>>> output = pipe(
|
||||
... prompt="high quality high resolution uhd 4k image",
|
||||
... rgb=low_res_rgb,
|
||||
... depth=low_res_depth,
|
||||
... num_inference_steps=50,
|
||||
... target_res=[1024, 1024],
|
||||
... )
|
||||
>>> rgb_image, depth_image = output.rgb, output.depth
|
||||
>>> rgb_image[0].save("hr_ldm3d_rgb.jpg")
|
||||
>>> depth_image[0].save("hr_ldm3d_depth.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionUpscaleLDM3DPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image and 3D generation using LDM3D.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.CLIPTextModel`]):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer ([`~transformers.CLIPTokenizer`]):
|
||||
A `CLIPTokenizer` to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded image latents.
|
||||
low_res_scheduler ([`SchedulerMixin`]):
|
||||
A scheduler used to add initial noise to the low resolution conditioning image. It must be an instance of
|
||||
[`DDPMScheduler`].
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
low_res_scheduler: DDPMScheduler,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
watermarker: Optional[Any] = None,
|
||||
max_noise_level: int = 350,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
watermarker=watermarker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample="bilinear")
|
||||
# self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
self.register_to_config(max_noise_level=max_noise_level)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
rgb_feature_extractor_input = feature_extractor_input[0]
|
||||
safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
noise_level,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
target_res=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, np.ndarray)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
image_batch_size = image.shape[0]
|
||||
if batch_size != image_batch_size:
|
||||
raise ValueError(
|
||||
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
|
||||
" Please make sure that passed `prompt` matches the batch size of `image`."
|
||||
)
|
||||
|
||||
# check noise level
|
||||
if noise_level > self.config.max_noise_level:
|
||||
raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# def upcast_vae(self):
|
||||
# dtype = self.vae.dtype
|
||||
# self.vae.to(dtype=torch.float32)
|
||||
# use_torch_2_0_or_xformers = isinstance(
|
||||
# self.vae.decoder.mid_block.attentions[0].processor,
|
||||
# (
|
||||
# AttnProcessor2_0,
|
||||
# XFormersAttnProcessor,
|
||||
# LoRAXFormersAttnProcessor,
|
||||
# LoRAAttnProcessor2_0,
|
||||
# ),
|
||||
# )
|
||||
# # if xformers or torch_2_0 is used attention block does not need
|
||||
# # to be in float32 which can save lots of memory
|
||||
# if use_torch_2_0_or_xformers:
|
||||
# self.vae.post_quant_conv.to(dtype)
|
||||
# self.vae.decoder.conv_in.to(dtype)
|
||||
# self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
rgb: PipelineImageInput = None,
|
||||
depth: PipelineDepthInput = None,
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
noise_level: int = 20,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
target_res: Optional[List[int]] = [1024, 1024],
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image` or tensor representing an image batch to be upscaled.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
rgb,
|
||||
noise_level,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Preprocess image
|
||||
rgb, depth = self.image_processor.preprocess(rgb, depth, target_res=target_res)
|
||||
rgb = rgb.to(dtype=prompt_embeds.dtype, device=device)
|
||||
depth = depth.to(dtype=prompt_embeds.dtype, device=device)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Encode low resolutiom image to latent space
|
||||
image = torch.cat([rgb, depth], axis=1)
|
||||
latent_space_image = self.vae.encode(image).latent_dist.sample(generator)
|
||||
latent_space_image *= self.vae.scaling_factor
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||
# noise_rgb = randn_tensor(rgb.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
|
||||
# rgb = self.low_res_scheduler.add_noise(rgb, noise_rgb, noise_level)
|
||||
# noise_depth = randn_tensor(depth.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
|
||||
# depth = self.low_res_scheduler.add_noise(depth, noise_depth, noise_level)
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
latent_space_image = torch.cat([latent_space_image] * batch_multiplier * num_images_per_prompt)
|
||||
noise_level = torch.cat([noise_level] * latent_space_image.shape[0])
|
||||
|
||||
# 7. Prepare latent variables
|
||||
height, width = latent_space_image.shape[2:]
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 8. Check that sizes of image and latents match
|
||||
num_channels_image = latent_space_image.shape[1]
|
||||
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
|
||||
raise ValueError(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 10. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = torch.cat([latent_model_input, latent_space_image], dim=1)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=noise_level,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# 11. Apply watermark
|
||||
if output_type == "pil" and self.watermarker is not None:
|
||||
rgb = self.watermarker.apply_watermark(rgb)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return ((rgb, depth), has_nsfw_concept)
|
||||
|
||||
return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,589 @@
|
||||
import math
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as FF
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import USE_PEFT_BACKEND
|
||||
|
||||
|
||||
try:
|
||||
from compel import Compel
|
||||
except ImportError:
|
||||
Compel = None
|
||||
|
||||
KCOMM = "ADDCOMM"
|
||||
KBRK = "BREAK"
|
||||
|
||||
|
||||
class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Args for Regional Prompting Pipeline:
|
||||
rp_args:dict
|
||||
Required
|
||||
rp_args["mode"]: cols, rows, prompt, prompt-ex
|
||||
for cols, rows mode
|
||||
rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
|
||||
for prompt, prompt-ex mode
|
||||
rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
|
||||
|
||||
Optional
|
||||
rp_args["save_mask"]: True/False (save masks in prompt mode)
|
||||
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
)
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: str = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
rp_args: Dict[str, str] = None,
|
||||
):
|
||||
active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721
|
||||
if negative_prompt is None:
|
||||
negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721
|
||||
|
||||
device = self._execution_device
|
||||
regions = 0
|
||||
|
||||
self.power = int(rp_args["power"]) if "power" in rp_args else 1
|
||||
|
||||
prompts = prompt if type(prompt) == list else [prompt] # noqa: E721
|
||||
n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721
|
||||
self.batch = batch = num_images_per_prompt * len(prompts)
|
||||
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
|
||||
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
|
||||
|
||||
cn = len(all_prompts_cn) == len(all_n_prompts_cn)
|
||||
|
||||
if Compel:
|
||||
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
|
||||
|
||||
def getcompelembs(prps):
|
||||
embl = []
|
||||
for prp in prps:
|
||||
embl.append(compel.build_conditioning_tensor(prp))
|
||||
return torch.cat(embl)
|
||||
|
||||
conds = getcompelembs(all_prompts_cn)
|
||||
unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
|
||||
embs = getcompelembs(prompts)
|
||||
n_embs = getcompelembs(n_prompts)
|
||||
prompt = negative_prompt = None
|
||||
else:
|
||||
conds = self.encode_prompt(prompts, device, 1, True)[0]
|
||||
unconds = (
|
||||
self.encode_prompt(n_prompts, device, 1, True)[0]
|
||||
if cn
|
||||
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
|
||||
)
|
||||
embs = n_embs = None
|
||||
|
||||
if not active:
|
||||
pcallback = None
|
||||
mode = None
|
||||
else:
|
||||
if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
|
||||
mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
|
||||
ocells, icells, regions = make_cells(rp_args["div"])
|
||||
|
||||
elif "PRO" in rp_args["mode"].upper():
|
||||
regions = len(all_prompts_p[0])
|
||||
mode = "PROMPT"
|
||||
reset_attnmaps(self)
|
||||
self.ex = "EX" in rp_args["mode"].upper()
|
||||
self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
|
||||
thresholds = [float(x) for x in rp_args["th"].split(",")]
|
||||
|
||||
orig_hw = (height, width)
|
||||
revers = True
|
||||
|
||||
def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor, selfs=None):
|
||||
if "PRO" in mode: # in Prompt mode, make masks from sum of attension maps
|
||||
self.step = step
|
||||
|
||||
if len(self.attnmaps_sizes) > 3:
|
||||
self.history[step] = self.attnmaps.copy()
|
||||
for hw in self.attnmaps_sizes:
|
||||
allmasks = []
|
||||
basemasks = [None] * batch
|
||||
for tt, th in zip(target_tokens, thresholds):
|
||||
for b in range(batch):
|
||||
key = f"{tt}-{b}"
|
||||
_, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
|
||||
mask = mask.unsqueeze(0).unsqueeze(-1)
|
||||
if self.ex:
|
||||
allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
|
||||
allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
|
||||
allmasks.append(mask)
|
||||
basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
|
||||
basemasks = [1 - mask for mask in basemasks]
|
||||
basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
|
||||
allmasks = basemasks + allmasks
|
||||
|
||||
self.attnmasks[hw] = torch.cat(allmasks)
|
||||
self.maskready = True
|
||||
return latents
|
||||
|
||||
def hook_forward(module):
|
||||
# diffusers==0.23.2
|
||||
def forward(
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
attn = module
|
||||
xshape = hidden_states.shape
|
||||
self.hw = (h, w) = split_dims(xshape[1], *orig_hw)
|
||||
|
||||
if revers:
|
||||
nx, px = hidden_states.chunk(2)
|
||||
else:
|
||||
px, nx = hidden_states.chunk(2)
|
||||
|
||||
if cn:
|
||||
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
|
||||
encoder_hidden_states = torch.cat([conds] + [unconds])
|
||||
else:
|
||||
hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
|
||||
encoder_hidden_states = torch.cat([conds] + [unconds])
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = scaled_dot_product_attention(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
getattn="PRO" in mode,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
#### Regional Prompting Col/Row mode
|
||||
if any(x in mode for x in ["COL", "ROW"]):
|
||||
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
|
||||
center = reshaped.shape[0] // 2
|
||||
px = reshaped[0:center] if cn else reshaped[0:-batch]
|
||||
nx = reshaped[center:] if cn else reshaped[-batch:]
|
||||
outs = [px, nx] if cn else [px]
|
||||
for out in outs:
|
||||
c = 0
|
||||
for i, ocell in enumerate(ocells):
|
||||
for icell in icells[i]:
|
||||
if "ROW" in mode:
|
||||
out[
|
||||
0:batch,
|
||||
int(h * ocell[0]) : int(h * ocell[1]),
|
||||
int(w * icell[0]) : int(w * icell[1]),
|
||||
:,
|
||||
] = out[
|
||||
c * batch : (c + 1) * batch,
|
||||
int(h * ocell[0]) : int(h * ocell[1]),
|
||||
int(w * icell[0]) : int(w * icell[1]),
|
||||
:,
|
||||
]
|
||||
else:
|
||||
out[
|
||||
0:batch,
|
||||
int(h * icell[0]) : int(h * icell[1]),
|
||||
int(w * ocell[0]) : int(w * ocell[1]),
|
||||
:,
|
||||
] = out[
|
||||
c * batch : (c + 1) * batch,
|
||||
int(h * icell[0]) : int(h * icell[1]),
|
||||
int(w * ocell[0]) : int(w * ocell[1]),
|
||||
:,
|
||||
]
|
||||
c += 1
|
||||
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
|
||||
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
|
||||
hidden_states = hidden_states.reshape(xshape)
|
||||
|
||||
#### Regional Prompting Prompt mode
|
||||
elif "PRO" in mode:
|
||||
center = reshaped.shape[0] // 2
|
||||
px = reshaped[0:center] if cn else reshaped[0:-batch]
|
||||
nx = reshaped[center:] if cn else reshaped[-batch:]
|
||||
|
||||
if (h, w) in self.attnmasks and self.maskready:
|
||||
|
||||
def mask(input):
|
||||
out = torch.multiply(input, self.attnmasks[(h, w)])
|
||||
for b in range(batch):
|
||||
for r in range(1, regions):
|
||||
out[b] = out[b] + out[r * batch + b]
|
||||
return out
|
||||
|
||||
px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
|
||||
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
|
||||
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
def hook_forwards(root_module: torch.nn.Module):
|
||||
for name, module in root_module.named_modules():
|
||||
if "attn2" in name and module.__class__.__name__ == "Attention":
|
||||
module.forward = hook_forward(module)
|
||||
|
||||
hook_forwards(self.unet)
|
||||
|
||||
output = StableDiffusionPipeline(**self.components)(
|
||||
prompt=prompt,
|
||||
prompt_embeds=embs,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_embeds=n_embs,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback_on_step_end=pcallback,
|
||||
)
|
||||
|
||||
if "save_mask" in rp_args:
|
||||
save_mask = rp_args["save_mask"]
|
||||
else:
|
||||
save_mask = False
|
||||
|
||||
if mode == "PROMPT" and save_mask:
|
||||
saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
### Make prompt list for each regions
|
||||
def promptsmaker(prompts, batch):
|
||||
out_p = []
|
||||
plen = len(prompts)
|
||||
for prompt in prompts:
|
||||
add = ""
|
||||
if KCOMM in prompt:
|
||||
add, prompt = prompt.split(KCOMM)
|
||||
add = add + " "
|
||||
prompts = prompt.split(KBRK)
|
||||
out_p.append([add + p for p in prompts])
|
||||
out = [None] * batch * len(out_p[0]) * len(out_p)
|
||||
for p, prs in enumerate(out_p): # inputs prompts
|
||||
for r, pr in enumerate(prs): # prompts for regions
|
||||
start = (p + r * plen) * batch
|
||||
out[start : start + batch] = [pr] * batch # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
|
||||
return out, out_p
|
||||
|
||||
|
||||
### make regions from ratios
|
||||
### ";" makes outercells, "," makes inner cells
|
||||
def make_cells(ratios):
|
||||
if ";" not in ratios and "," in ratios:
|
||||
ratios = ratios.replace(",", ";")
|
||||
ratios = ratios.split(";")
|
||||
ratios = [inratios.split(",") for inratios in ratios]
|
||||
|
||||
icells = []
|
||||
ocells = []
|
||||
|
||||
def startend(cells, array):
|
||||
current_start = 0
|
||||
array = [float(x) for x in array]
|
||||
for value in array:
|
||||
end = current_start + (value / sum(array))
|
||||
cells.append([current_start, end])
|
||||
current_start = end
|
||||
|
||||
startend(ocells, [r[0] for r in ratios])
|
||||
|
||||
for inratios in ratios:
|
||||
if 2 > len(inratios):
|
||||
icells.append([[0, 1]])
|
||||
else:
|
||||
add = []
|
||||
startend(add, inratios[1:])
|
||||
icells.append(add)
|
||||
|
||||
return ocells, icells, sum(len(cell) for cell in icells)
|
||||
|
||||
|
||||
def make_emblist(self, prompts):
|
||||
with torch.no_grad():
|
||||
tokens = self.tokenizer(
|
||||
prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
|
||||
).input_ids.to(self.device)
|
||||
embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
|
||||
return embs
|
||||
|
||||
|
||||
def split_dims(xs, height, width):
|
||||
xs = xs
|
||||
|
||||
def repeat_div(x, y):
|
||||
while y > 0:
|
||||
x = math.ceil(x / 2)
|
||||
y = y - 1
|
||||
return x
|
||||
|
||||
scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
|
||||
dsh = repeat_div(height, scale)
|
||||
dsw = repeat_div(width, scale)
|
||||
return dsh, dsw
|
||||
|
||||
|
||||
##### for prompt mode
|
||||
def get_attn_maps(self, attn):
|
||||
height, width = self.hw
|
||||
target_tokens = self.target_tokens
|
||||
if (height, width) not in self.attnmaps_sizes:
|
||||
self.attnmaps_sizes.append((height, width))
|
||||
|
||||
for b in range(self.batch):
|
||||
for t in target_tokens:
|
||||
power = self.power
|
||||
add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
|
||||
add = torch.sum(add, dim=2)
|
||||
key = f"{t}-{b}"
|
||||
if key not in self.attnmaps:
|
||||
self.attnmaps[key] = add
|
||||
else:
|
||||
if self.attnmaps[key].shape[1] != add.shape[1]:
|
||||
add = add.view(8, height, width)
|
||||
add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
|
||||
add = add.reshape_as(self.attnmaps[key])
|
||||
|
||||
self.attnmaps[key] = self.attnmaps[key] + add
|
||||
|
||||
|
||||
def reset_attnmaps(self): # init parameters in every batch
|
||||
self.step = 0
|
||||
self.attnmaps = {} # maked from attention maps
|
||||
self.attnmaps_sizes = [] # height,width set of u-net blocks
|
||||
self.attnmasks = {} # maked from attnmaps for regions
|
||||
self.maskready = False
|
||||
self.history = {}
|
||||
|
||||
|
||||
def saveattnmaps(self, output, h, w, th, step, regions):
|
||||
masks = []
|
||||
for i, mask in enumerate(self.history[step].values()):
|
||||
img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
|
||||
if self.ex:
|
||||
masks = [x - mask for x in masks]
|
||||
masks.append(mask)
|
||||
if len(masks) == regions - 1:
|
||||
output.images.extend([FF.to_pil_image(mask) for mask in masks])
|
||||
masks = []
|
||||
else:
|
||||
output.images.append(img)
|
||||
|
||||
|
||||
def makepmask(
|
||||
self, mask, h, w, th, step
|
||||
): # make masks from attention cache return [for preview, for attention, for Latent]
|
||||
th = th - step * 0.005
|
||||
if 0.05 >= th:
|
||||
th = 0.05
|
||||
mask = torch.mean(mask, dim=0)
|
||||
mask = mask / mask.max().item()
|
||||
mask = torch.where(mask > th, 1, 0)
|
||||
mask = mask.float()
|
||||
mask = mask.view(1, *self.attnmaps_sizes[0])
|
||||
img = FF.to_pil_image(mask)
|
||||
img = img.resize((w, h))
|
||||
mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
|
||||
lmask = mask
|
||||
mask = mask.reshape(h * w)
|
||||
mask = torch.where(mask > 0.1, 1, 0)
|
||||
return img, mask, lmask
|
||||
|
||||
|
||||
def tokendealer(self, all_prompts):
|
||||
for prompts in all_prompts:
|
||||
targets = [p.split(",")[-1] for p in prompts[1:]]
|
||||
tt = []
|
||||
|
||||
for target in targets:
|
||||
ptokens = (
|
||||
self.tokenizer(
|
||||
prompts,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
)[0]
|
||||
ttokens = (
|
||||
self.tokenizer(
|
||||
target,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
)[0]
|
||||
|
||||
tlist = []
|
||||
|
||||
for t in range(ttokens.shape[0] - 2):
|
||||
for p in range(ptokens.shape[0]):
|
||||
if ttokens[t + 1] == ptokens[p]:
|
||||
tlist.append(p)
|
||||
if tlist != []:
|
||||
tt.append(tlist)
|
||||
|
||||
return tt
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
|
||||
) -> torch.Tensor:
|
||||
# Efficient implementation equivalent to the following:
|
||||
L, S = query.size(-2), key.size(-2)
|
||||
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
||||
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
|
||||
if is_causal:
|
||||
assert attn_mask is None
|
||||
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
||||
else:
|
||||
attn_bias += attn_mask
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight += attn_bias
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
if getattn:
|
||||
get_attn_maps(self, attn_weight)
|
||||
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
|
||||
return attn_weight @ value
|
||||
@@ -0,0 +1,594 @@
|
||||
import math
|
||||
import tempfile
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel
|
||||
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
|
||||
class SdeDragPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for image drag-and-drop editing using stochastic differential equations: https://arxiv.org/abs/2311.01410.
|
||||
Please refer to the [official repository](https://github.com/ML-GSAI/SDE-Drag) for more information.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Please use
|
||||
[`DDIMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
image: PIL.Image.Image,
|
||||
mask_image: PIL.Image.Image,
|
||||
source_points: List[List[int]],
|
||||
target_points: List[List[int]],
|
||||
t0: Optional[float] = 0.6,
|
||||
steps: Optional[int] = 200,
|
||||
step_size: Optional[int] = 2,
|
||||
image_scale: Optional[float] = 0.3,
|
||||
adapt_radius: Optional[int] = 5,
|
||||
min_lora_scale: Optional[float] = 0.5,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for image editing.
|
||||
Args:
|
||||
prompt (`str`, *required*):
|
||||
The prompt to guide the image editing.
|
||||
image (`PIL.Image.Image`, *required*):
|
||||
Which will be edited, parts of the image will be masked out with `mask_image` and edited
|
||||
according to `prompt`.
|
||||
mask_image (`PIL.Image.Image`, *required*):
|
||||
To mask `image`. White pixels in the mask will be edited, while black pixels will be preserved.
|
||||
source_points (`List[List[int]]`, *required*):
|
||||
Used to mark the starting positions of drag editing in the image, with each pixel represented as a
|
||||
`List[int]` of length 2.
|
||||
target_points (`List[List[int]]`, *required*):
|
||||
Used to mark the target positions of drag editing in the image, with each pixel represented as a
|
||||
`List[int]` of length 2.
|
||||
t0 (`float`, *optional*, defaults to 0.6):
|
||||
The time parameter. Higher t0 improves the fidelity while lowering the faithfulness of the edited images
|
||||
and vice versa.
|
||||
steps (`int`, *optional*, defaults to 200):
|
||||
The number of sampling iterations.
|
||||
step_size (`int`, *optional*, defaults to 2):
|
||||
The drag diatance of each drag step.
|
||||
image_scale (`float`, *optional*, defaults to 0.3):
|
||||
To avoid duplicating the content, use image_scale to perturbs the source.
|
||||
adapt_radius (`int`, *optional*, defaults to 5):
|
||||
The size of the region for copy and paste operations during each step of the drag process.
|
||||
min_lora_scale (`float`, *optional*, defaults to 0.5):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
min_lora_scale specifies the minimum LoRA scale during the image drag-editing process.
|
||||
generator ('torch.Generator', *optional*, defaults to None):
|
||||
To make generation deterministic(https://pytorch.org/docs/stable/generated/torch.Generator.html).
|
||||
Examples:
|
||||
```py
|
||||
>>> import PIL
|
||||
>>> import torch
|
||||
>>> from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
|
||||
>>> # Load the pipeline
|
||||
>>> model_path = "runwayml/stable-diffusion-v1-5"
|
||||
>>> scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
>>> pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
|
||||
>>> pipe.to('cuda')
|
||||
|
||||
>>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
|
||||
>>> # If not training LoRA, please avoid using torch.float16
|
||||
>>> # pipe.to(torch.float16)
|
||||
|
||||
>>> # Provide prompt, image, mask image, and the starting and target points for drag editing.
|
||||
>>> prompt = "prompt of the image"
|
||||
>>> image = PIL.Image.open('/path/to/image')
|
||||
>>> mask_image = PIL.Image.open('/path/to/mask_image')
|
||||
>>> source_points = [[123, 456]]
|
||||
>>> target_points = [[234, 567]]
|
||||
|
||||
>>> # train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.
|
||||
>>> pipe.train_lora(prompt, image)
|
||||
|
||||
>>> output = pipe(prompt, image, mask_image, source_points, target_points)
|
||||
>>> output_image = PIL.Image.fromarray(output)
|
||||
>>> output_image.save("./output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
self.scheduler.set_timesteps(steps)
|
||||
|
||||
noise_scale = (1 - image_scale**2) ** (0.5)
|
||||
|
||||
text_embeddings = self._get_text_embed(prompt)
|
||||
uncond_embeddings = self._get_text_embed([""])
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latent = self._get_img_latent(image)
|
||||
|
||||
mask = mask_image.resize((latent.shape[3], latent.shape[2]))
|
||||
mask = torch.tensor(np.array(mask))
|
||||
mask = mask.unsqueeze(0).expand_as(latent).to(self.device)
|
||||
|
||||
source_points = torch.tensor(source_points).div(torch.tensor([8]), rounding_mode="trunc")
|
||||
target_points = torch.tensor(target_points).div(torch.tensor([8]), rounding_mode="trunc")
|
||||
|
||||
distance = target_points - source_points
|
||||
distance_norm_max = torch.norm(distance.float(), dim=1, keepdim=True).max()
|
||||
|
||||
if distance_norm_max <= step_size:
|
||||
drag_num = 1
|
||||
else:
|
||||
drag_num = distance_norm_max.div(torch.tensor([step_size]), rounding_mode="trunc")
|
||||
if (distance_norm_max / drag_num - step_size).abs() > (
|
||||
distance_norm_max / (drag_num + 1) - step_size
|
||||
).abs():
|
||||
drag_num += 1
|
||||
|
||||
latents = []
|
||||
for i in tqdm(range(int(drag_num)), desc="SDE Drag"):
|
||||
source_new = source_points + (i / drag_num * distance).to(torch.int)
|
||||
target_new = source_points + ((i + 1) / drag_num * distance).to(torch.int)
|
||||
|
||||
latent, noises, hook_latents, lora_scales, cfg_scales = self._forward(
|
||||
latent, steps, t0, min_lora_scale, text_embeddings, generator
|
||||
)
|
||||
latent = self._copy_and_paste(
|
||||
latent,
|
||||
source_new,
|
||||
target_new,
|
||||
adapt_radius,
|
||||
latent.shape[2] - 1,
|
||||
latent.shape[3] - 1,
|
||||
image_scale,
|
||||
noise_scale,
|
||||
generator,
|
||||
)
|
||||
latent = self._backward(
|
||||
latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
|
||||
)
|
||||
|
||||
latents.append(latent)
|
||||
|
||||
result_image = 1 / 0.18215 * latents[-1]
|
||||
|
||||
with torch.no_grad():
|
||||
result_image = self.vae.decode(result_image).sample
|
||||
|
||||
result_image = (result_image / 2 + 0.5).clamp(0, 1)
|
||||
result_image = result_image.cpu().permute(0, 2, 3, 1).numpy()[0]
|
||||
result_image = (result_image * 255).astype(np.uint8)
|
||||
|
||||
return result_image
|
||||
|
||||
def train_lora(self, prompt, image, lora_step=100, lora_rank=16, generator=None):
|
||||
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision="fp16")
|
||||
|
||||
self.vae.requires_grad_(False)
|
||||
self.text_encoder.requires_grad_(False)
|
||||
self.unet.requires_grad_(False)
|
||||
|
||||
unet_lora_attn_procs = {}
|
||||
for name, attn_processor in self.unet.attn_processors.items():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = self.unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = self.unet.config.block_out_channels[block_id]
|
||||
else:
|
||||
raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks")
|
||||
|
||||
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
|
||||
lora_attn_processor_class = LoRAAttnAddedKVProcessor
|
||||
else:
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
||||
else LoRAAttnProcessor
|
||||
)
|
||||
unet_lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
|
||||
)
|
||||
|
||||
self.unet.set_attn_processor(unet_lora_attn_procs)
|
||||
unet_lora_layers = AttnProcsLayers(self.unet.attn_processors)
|
||||
params_to_optimize = unet_lora_layers.parameters()
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
params_to_optimize,
|
||||
lr=2e-4,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=1e-2,
|
||||
eps=1e-08,
|
||||
)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
"constant",
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=lora_step,
|
||||
num_cycles=1,
|
||||
power=1.0,
|
||||
)
|
||||
|
||||
unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
|
||||
optimizer = accelerator.prepare_optimizer(optimizer)
|
||||
lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
|
||||
|
||||
with torch.no_grad():
|
||||
text_inputs = self._tokenize_prompt(prompt, tokenizer_max_length=None)
|
||||
text_embedding = self._encode_prompt(
|
||||
text_inputs.input_ids, text_inputs.attention_mask, text_encoder_use_attention_mask=False
|
||||
)
|
||||
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
image = image_transforms(image).to(self.device, dtype=self.vae.dtype)
|
||||
image = image.unsqueeze(dim=0)
|
||||
latents_dist = self.vae.encode(image).latent_dist
|
||||
|
||||
for _ in tqdm(range(lora_step), desc="Train LoRA"):
|
||||
self.unet.train()
|
||||
model_input = latents_dist.sample() * self.vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(
|
||||
model_input.size(),
|
||||
dtype=model_input.dtype,
|
||||
layout=model_input.layout,
|
||||
device=model_input.device,
|
||||
generator=generator,
|
||||
)
|
||||
bsz, channels, height, width = model_input.shape
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, self.scheduler.config.num_train_timesteps, (bsz,), device=model_input.device, generator=generator
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the model input according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = self.scheduler.add_noise(model_input, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = self.unet(noisy_model_input, timesteps, text_embedding).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if self.scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif self.scheduler.config.prediction_type == "v_prediction":
|
||||
target = self.scheduler.get_velocity(model_input, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")
|
||||
|
||||
loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
with tempfile.TemporaryDirectory() as save_lora_dir:
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=save_lora_dir,
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=None,
|
||||
)
|
||||
|
||||
self.unet.load_attn_procs(save_lora_dir)
|
||||
|
||||
def _tokenize_prompt(self, prompt, tokenizer_max_length=None):
|
||||
if tokenizer_max_length is not None:
|
||||
max_length = tokenizer_max_length
|
||||
else:
|
||||
max_length = self.tokenizer.model_max_length
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
return text_inputs
|
||||
|
||||
def _encode_prompt(self, input_ids, attention_mask, text_encoder_use_attention_mask=False):
|
||||
text_input_ids = input_ids.to(self.device)
|
||||
|
||||
if text_encoder_use_attention_mask:
|
||||
attention_mask = attention_mask.to(self.device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_text_embed(self, prompt):
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
return text_embeddings
|
||||
|
||||
def _copy_and_paste(
|
||||
self, latent, source_new, target_new, adapt_radius, max_height, max_width, image_scale, noise_scale, generator
|
||||
):
|
||||
def adaption_r(source, target, adapt_radius, max_height, max_width):
|
||||
r_x_lower = min(adapt_radius, source[0], target[0])
|
||||
r_x_upper = min(adapt_radius, max_width - source[0], max_width - target[0])
|
||||
r_y_lower = min(adapt_radius, source[1], target[1])
|
||||
r_y_upper = min(adapt_radius, max_height - source[1], max_height - target[1])
|
||||
return r_x_lower, r_x_upper, r_y_lower, r_y_upper
|
||||
|
||||
for source_, target_ in zip(source_new, target_new):
|
||||
r_x_lower, r_x_upper, r_y_lower, r_y_upper = adaption_r(
|
||||
source_, target_, adapt_radius, max_height, max_width
|
||||
)
|
||||
|
||||
source_feature = latent[
|
||||
:, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
|
||||
].clone()
|
||||
|
||||
latent[
|
||||
:, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
|
||||
] = image_scale * source_feature + noise_scale * torch.randn(
|
||||
latent.shape[0],
|
||||
4,
|
||||
r_y_lower + r_y_upper,
|
||||
r_x_lower + r_x_upper,
|
||||
device=self.device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
latent[
|
||||
:, :, target_[1] - r_y_lower : target_[1] + r_y_upper, target_[0] - r_x_lower : target_[0] + r_x_upper
|
||||
] = source_feature * 1.1
|
||||
return latent
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_img_latent(self, image, height=None, weight=None):
|
||||
data = image.convert("RGB")
|
||||
if height is not None:
|
||||
data = data.resize((weight, height))
|
||||
transform = transforms.ToTensor()
|
||||
data = transform(data).unsqueeze(0)
|
||||
data = (data * 2.0) - 1.0
|
||||
data = data.to(self.device, dtype=self.vae.dtype)
|
||||
latent = self.vae.encode(data).latent_dist.sample()
|
||||
latent = 0.18215 * latent
|
||||
return latent
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_eps(self, latent, timestep, guidance_scale, text_embeddings, lora_scale=None):
|
||||
latent_model_input = torch.cat([latent] * 2) if guidance_scale > 1.0 else latent
|
||||
text_embeddings = text_embeddings if guidance_scale > 1.0 else text_embeddings.chunk(2)[1]
|
||||
|
||||
cross_attention_kwargs = None if lora_scale is None else {"scale": lora_scale}
|
||||
|
||||
with torch.no_grad():
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
elif guidance_scale == 1.0:
|
||||
noise_pred_text = noise_pred
|
||||
noise_pred_uncond = 0.0
|
||||
else:
|
||||
raise NotImplementedError(guidance_scale)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
return noise_pred
|
||||
|
||||
def _forward_sde(
|
||||
self, timestep, sample, guidance_scale, text_embeddings, steps, eta=1.0, lora_scale=None, generator=None
|
||||
):
|
||||
num_train_timesteps = len(self.scheduler)
|
||||
alphas_cumprod = self.scheduler.alphas_cumprod
|
||||
initial_alpha_cumprod = torch.tensor(1.0)
|
||||
|
||||
prev_timestep = timestep + num_train_timesteps // steps
|
||||
|
||||
alpha_prod_t = alphas_cumprod[timestep] if timestep >= 0 else initial_alpha_cumprod
|
||||
alpha_prod_t_prev = alphas_cumprod[prev_timestep]
|
||||
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
x_prev = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) * sample + (1 - alpha_prod_t_prev / alpha_prod_t) ** (
|
||||
0.5
|
||||
) * torch.randn(
|
||||
sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
|
||||
)
|
||||
eps = self._get_eps(x_prev, prev_timestep, guidance_scale, text_embeddings, lora_scale)
|
||||
|
||||
sigma_t_prev = (
|
||||
eta
|
||||
* (1 - alpha_prod_t) ** (0.5)
|
||||
* (1 - alpha_prod_t_prev / (1 - alpha_prod_t_prev) * (1 - alpha_prod_t) / alpha_prod_t) ** (0.5)
|
||||
)
|
||||
|
||||
pred_original_sample = (x_prev - beta_prod_t_prev ** (0.5) * eps) / alpha_prod_t_prev ** (0.5)
|
||||
pred_sample_direction_coeff = (1 - alpha_prod_t - sigma_t_prev**2) ** (0.5)
|
||||
|
||||
noise = (
|
||||
sample - alpha_prod_t ** (0.5) * pred_original_sample - pred_sample_direction_coeff * eps
|
||||
) / sigma_t_prev
|
||||
|
||||
return x_prev, noise
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
timestep,
|
||||
sample,
|
||||
guidance_scale,
|
||||
text_embeddings,
|
||||
steps,
|
||||
sde=False,
|
||||
noise=None,
|
||||
eta=1.0,
|
||||
lora_scale=None,
|
||||
generator=None,
|
||||
):
|
||||
num_train_timesteps = len(self.scheduler)
|
||||
alphas_cumprod = self.scheduler.alphas_cumprod
|
||||
final_alpha_cumprod = torch.tensor(1.0)
|
||||
|
||||
eps = self._get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale)
|
||||
|
||||
prev_timestep = timestep - num_train_timesteps // steps
|
||||
|
||||
alpha_prod_t = alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
sigma_t = (
|
||||
eta
|
||||
* ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** (0.5)
|
||||
* (1 - alpha_prod_t / alpha_prod_t_prev) ** (0.5)
|
||||
if sde
|
||||
else 0
|
||||
)
|
||||
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5)
|
||||
pred_sample_direction_coeff = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5)
|
||||
|
||||
noise = (
|
||||
torch.randn(
|
||||
sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
|
||||
)
|
||||
if noise is None
|
||||
else noise
|
||||
)
|
||||
latent = (
|
||||
alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction_coeff * eps + sigma_t * noise
|
||||
)
|
||||
|
||||
return latent
|
||||
|
||||
def _forward(self, latent, steps, t0, lora_scale_min, text_embeddings, generator):
|
||||
def scale_schedule(begin, end, n, length, type="linear"):
|
||||
if type == "constant":
|
||||
return end
|
||||
elif type == "linear":
|
||||
return begin + (end - begin) * n / length
|
||||
elif type == "cos":
|
||||
factor = (1 - math.cos(n * math.pi / length)) / 2
|
||||
return (1 - factor) * begin + factor * end
|
||||
else:
|
||||
raise NotImplementedError(type)
|
||||
|
||||
noises = []
|
||||
latents = []
|
||||
lora_scales = []
|
||||
cfg_scales = []
|
||||
latents.append(latent)
|
||||
t0 = int(t0 * steps)
|
||||
t_begin = steps - t0
|
||||
|
||||
length = len(self.scheduler.timesteps[t_begin - 1 : -1]) - 1
|
||||
index = 1
|
||||
for t in self.scheduler.timesteps[t_begin:].flip(dims=[0]):
|
||||
lora_scale = scale_schedule(1, lora_scale_min, index, length, type="cos")
|
||||
cfg_scale = scale_schedule(1, 3.0, index, length, type="linear")
|
||||
latent, noise = self._forward_sde(
|
||||
t, latent, cfg_scale, text_embeddings, steps, lora_scale=lora_scale, generator=generator
|
||||
)
|
||||
|
||||
noises.append(noise)
|
||||
latents.append(latent)
|
||||
lora_scales.append(lora_scale)
|
||||
cfg_scales.append(cfg_scale)
|
||||
index += 1
|
||||
return latent, noises, latents, lora_scales, cfg_scales
|
||||
|
||||
def _backward(
|
||||
self, latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
|
||||
):
|
||||
t0 = int(t0 * steps)
|
||||
t_begin = steps - t0
|
||||
|
||||
hook_latent = hook_latents.pop()
|
||||
latent = torch.where(mask > 128, latent, hook_latent)
|
||||
for t in self.scheduler.timesteps[t_begin - 1 : -1]:
|
||||
latent = self._sample(
|
||||
t,
|
||||
latent,
|
||||
cfg_scales.pop(),
|
||||
text_embeddings,
|
||||
steps,
|
||||
sde=True,
|
||||
noise=noises.pop(),
|
||||
lora_scale=lora_scales.pop(),
|
||||
generator=generator,
|
||||
)
|
||||
hook_latent = hook_latents.pop()
|
||||
latent = torch.where(mask > 128, latent, hook_latent)
|
||||
return latent
|
||||
@@ -21,7 +21,7 @@ from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.loaders import TextualInversionLoaderMixin
|
||||
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
@@ -62,7 +62,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion on IPEX.
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import PIL.Image
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from onnx import shape_inference
|
||||
from polygraphy import cuda
|
||||
from polygraphy.backend.common import bytes_from_path
|
||||
@@ -41,7 +42,7 @@ from polygraphy.backend.trt import (
|
||||
save_engine,
|
||||
)
|
||||
from polygraphy.backend.trt import util as trt_util
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
@@ -50,7 +51,7 @@ from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils import DIFFUSERS_CACHE, logging
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
"""
|
||||
@@ -709,6 +710,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae", "vae_encoder"],
|
||||
image_height: int = 512,
|
||||
@@ -724,7 +726,15 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
timing_cache: str = "timing_cache",
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
@@ -769,12 +779,13 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
cls.cached_folder = (
|
||||
@@ -786,7 +797,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
token=token,
|
||||
revision=revision,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ import PIL.Image
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from onnx import shape_inference
|
||||
from polygraphy import cuda
|
||||
from polygraphy.backend.common import bytes_from_path
|
||||
@@ -41,7 +42,7 @@ from polygraphy.backend.trt import (
|
||||
save_engine,
|
||||
)
|
||||
from polygraphy.backend.trt import util as trt_util
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
@@ -51,7 +52,7 @@ from diffusers.pipelines.stable_diffusion import (
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils import DIFFUSERS_CACHE, logging
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
"""
|
||||
@@ -710,6 +711,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae", "vae_encoder"],
|
||||
image_height: int = 512,
|
||||
@@ -725,7 +727,15 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
timing_cache: str = "timing_cache",
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
@@ -770,12 +780,13 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
cls.cached_folder = (
|
||||
@@ -787,7 +798,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
token=token,
|
||||
revision=revision,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -27,6 +27,7 @@ import onnx_graphsurgeon as gs
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from onnx import shape_inference
|
||||
from polygraphy import cuda
|
||||
from polygraphy.backend.common import bytes_from_path
|
||||
@@ -40,7 +41,7 @@ from polygraphy.backend.trt import (
|
||||
save_engine,
|
||||
)
|
||||
from polygraphy.backend.trt import util as trt_util
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
@@ -49,7 +50,7 @@ from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils import DIFFUSERS_CACHE, logging
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
"""
|
||||
@@ -624,6 +625,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae"],
|
||||
image_height: int = 768,
|
||||
@@ -639,7 +641,15 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
timing_cache: str = "timing_cache",
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
@@ -682,12 +692,13 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
self.models["vae"] = make_VAE(self.vae, **models_args)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
cls.cached_folder = (
|
||||
@@ -699,7 +710,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
token=token,
|
||||
revision=revision,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Latent Consistency Distillation Example:
|
||||
|
||||
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill stable-diffusion-v1.5 for less timestep inference.
|
||||
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with few timesteps.
|
||||
|
||||
## Full model distillation
|
||||
|
||||
@@ -24,7 +24,7 @@ Then cd in the example folder and run
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
@@ -46,12 +46,16 @@ write_basic_config()
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
|
||||
#### Example with LAION-A6+ dataset
|
||||
#### Example
|
||||
|
||||
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
|
||||
|
||||
```bash
|
||||
runwayml/stable-diffusion-v1-5
|
||||
PROGRAM="train_lcm_distill_sd_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_DIR \
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="path/to/saved/model"
|
||||
|
||||
accelerate launch train_lcm_distill_sd_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=512 \
|
||||
@@ -59,7 +63,7 @@ PROGRAM="train_lcm_distill_sd_wds.py \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
|
||||
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
|
||||
--validation_steps=200 \
|
||||
--checkpointing_steps=200 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=12 \
|
||||
@@ -69,19 +73,23 @@ PROGRAM="train_lcm_distill_sd_wds.py \
|
||||
--resume_from_checkpoint=latest \
|
||||
--report_to=wandb \
|
||||
--seed=453645634 \
|
||||
--push_to_hub \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## LCM-LoRA
|
||||
|
||||
Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
|
||||
|
||||
### Example with LAION-A6+ dataset
|
||||
|
||||
### Example
|
||||
|
||||
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
|
||||
|
||||
```bash
|
||||
runwayml/stable-diffusion-v1-5
|
||||
PROGRAM="train_lcm_distill_lora_sd_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_DIR \
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="path/to/saved/model"
|
||||
|
||||
accelerate launch train_lcm_distill_lora_sd_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=512 \
|
||||
@@ -90,7 +98,7 @@ PROGRAM="train_lcm_distill_lora_sd_wds.py \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
|
||||
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
|
||||
--validation_steps=200 \
|
||||
--checkpointing_steps=200 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=12 \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Latent Consistency Distillation Example:
|
||||
|
||||
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill SDXL for less timestep inference.
|
||||
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with few timesteps.
|
||||
|
||||
## Full model distillation
|
||||
|
||||
@@ -24,7 +24,7 @@ Then cd in the example folder and run
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
@@ -46,12 +46,16 @@ write_basic_config()
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
|
||||
#### Example with LAION-A6+ dataset
|
||||
#### Example
|
||||
|
||||
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
PROGRAM="train_lcm_distill_sdxl_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_DIR \
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export OUTPUT_DIR="path/to/saved/model"
|
||||
|
||||
accelerate launch train_lcm_distill_sdxl_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_NAME \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision=fp16 \
|
||||
@@ -60,7 +64,7 @@ PROGRAM="train_lcm_distill_sdxl_wds.py \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
|
||||
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
|
||||
--validation_steps=200 \
|
||||
--checkpointing_steps=200 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=12 \
|
||||
@@ -77,11 +81,15 @@ PROGRAM="train_lcm_distill_sdxl_wds.py \
|
||||
|
||||
Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
|
||||
|
||||
### Example with LAION-A6+ dataset
|
||||
|
||||
### Example
|
||||
|
||||
The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export OUTPUT_DIR="path/to/saved/model"
|
||||
|
||||
accelerate launch train_lcm_distill_lora_sdxl_wds.py \
|
||||
--pretrained_teacher_model=$MODEL_DIR \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
@@ -92,7 +100,7 @@ PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
|
||||
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
|
||||
--validation_steps=200 \
|
||||
--checkpointing_steps=200 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=12 \
|
||||
|
||||
@@ -71,7 +71,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -423,7 +423,7 @@ def import_model_class_from_model_name_or_path(
|
||||
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
||||
):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
@@ -1123,7 +1123,7 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
image, text, _, _ = batch
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
encoded_text = compute_embeddings_fn(text)
|
||||
|
||||
@@ -68,11 +68,16 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
MAX_SEQ_LENGTH = 77
|
||||
|
||||
# Adjust for your dataset
|
||||
WDS_JSON_WIDTH = "width" # original_width for LAION
|
||||
WDS_JSON_HEIGHT = "height" # original_height for LAION
|
||||
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -146,10 +151,10 @@ class WebdatasetFilter:
|
||||
try:
|
||||
if "json" in x:
|
||||
x_json = json.loads(x["json"])
|
||||
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
|
||||
"original_height", 0
|
||||
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
|
||||
WDS_JSON_HEIGHT, 0
|
||||
) >= self.min_size
|
||||
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
|
||||
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
|
||||
return filter_size and filter_watermark
|
||||
else:
|
||||
return False
|
||||
@@ -180,7 +185,7 @@ class Text2ImageDataset:
|
||||
if use_fix_crop_and_size:
|
||||
return (resolution, resolution)
|
||||
else:
|
||||
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
|
||||
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
|
||||
|
||||
def transform(example):
|
||||
# resize image
|
||||
@@ -212,7 +217,7 @@ class Text2ImageDataset:
|
||||
pipeline = [
|
||||
wds.ResampledShards(train_shards_path_or_url),
|
||||
tarfile_to_samples_nothrow,
|
||||
wds.select(WebdatasetFilter(min_size=960)),
|
||||
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
|
||||
wds.shuffle(shuffle_buffer_size),
|
||||
*processing_pipeline,
|
||||
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
|
||||
@@ -392,7 +397,7 @@ def import_model_class_from_model_name_or_path(
|
||||
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
||||
):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -400,7 +400,7 @@ def import_model_class_from_model_name_or_path(
|
||||
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
||||
):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
@@ -657,6 +657,15 @@ def parse_args():
|
||||
default=0.001,
|
||||
help="The huber loss parameter. Only used if `--loss_type=huber`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unet_time_cond_proj_dim",
|
||||
type=int,
|
||||
default=256,
|
||||
help=(
|
||||
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
|
||||
" does not have `time_cond_proj_dim` set."
|
||||
),
|
||||
)
|
||||
# ----Exponential Moving Average (EMA)----
|
||||
parser.add_argument(
|
||||
"--ema_decay",
|
||||
@@ -1097,7 +1106,7 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
image, text, _, _ = batch
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
encoded_text = compute_embeddings_fn(text)
|
||||
@@ -1138,7 +1147,7 @@ def main(args):
|
||||
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
# Move to U-Net device and dtype
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
@@ -67,11 +67,16 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
MAX_SEQ_LENGTH = 77
|
||||
|
||||
# Adjust for your dataset
|
||||
WDS_JSON_WIDTH = "width" # original_width for LAION
|
||||
WDS_JSON_HEIGHT = "height" # original_height for LAION
|
||||
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -128,10 +133,10 @@ class WebdatasetFilter:
|
||||
try:
|
||||
if "json" in x:
|
||||
x_json = json.loads(x["json"])
|
||||
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
|
||||
"original_height", 0
|
||||
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
|
||||
WDS_JSON_HEIGHT, 0
|
||||
) >= self.min_size
|
||||
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
|
||||
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
|
||||
return filter_size and filter_watermark
|
||||
else:
|
||||
return False
|
||||
@@ -162,7 +167,7 @@ class Text2ImageDataset:
|
||||
if use_fix_crop_and_size:
|
||||
return (resolution, resolution)
|
||||
else:
|
||||
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
|
||||
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
|
||||
|
||||
def transform(example):
|
||||
# resize image
|
||||
@@ -194,7 +199,7 @@ class Text2ImageDataset:
|
||||
pipeline = [
|
||||
wds.ResampledShards(train_shards_path_or_url),
|
||||
tarfile_to_samples_nothrow,
|
||||
wds.select(WebdatasetFilter(min_size=960)),
|
||||
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
|
||||
wds.shuffle(shuffle_buffer_size),
|
||||
*processing_pipeline,
|
||||
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
|
||||
@@ -414,7 +419,7 @@ def import_model_class_from_model_name_or_path(
|
||||
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
||||
):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
@@ -677,6 +682,15 @@ def parse_args():
|
||||
default=0.001,
|
||||
help="The huber loss parameter. Only used if `--loss_type=huber`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unet_time_cond_proj_dim",
|
||||
type=int,
|
||||
default=256,
|
||||
help=(
|
||||
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
|
||||
" does not have `time_cond_proj_dim` set."
|
||||
),
|
||||
)
|
||||
# ----Exponential Moving Average (EMA)----
|
||||
parser.add_argument(
|
||||
"--ema_decay",
|
||||
@@ -1233,6 +1247,7 @@ def main(args):
|
||||
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
@@ -1243,7 +1258,7 @@ def main(args):
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
timestep_cond=None,
|
||||
timestep_cond=w_embedding,
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
@@ -1308,7 +1323,7 @@ def main(args):
|
||||
target_noise_pred = target_unet(
|
||||
x_prev.float(),
|
||||
timesteps,
|
||||
timestep_cond=None,
|
||||
timestep_cond=w_embedding,
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class ControlNet(ExamplesTestsAccelerate):
|
||||
def test_controlnet_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/controlnet/train_controlnet.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/controlnet/train_controlnet.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/controlnet/train_controlnet.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
|
||||
)
|
||||
|
||||
|
||||
class ControlNetSDXL(ExamplesTestsAccelerate):
|
||||
def test_controlnet_sdxl(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/controlnet/train_controlnet_sdxl.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -86,6 +86,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
@@ -249,10 +250,13 @@ def parse_args(input_args=None):
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=(
|
||||
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
||||
" float32 precision."
|
||||
),
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
@@ -767,11 +771,13 @@ def main(args):
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
if args.controlnet_model_name_or_path:
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -74,6 +74,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
@@ -243,15 +244,18 @@ def parse_args(input_args=None):
|
||||
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
|
||||
" If not specified controlnet weights are initialized from unet.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=(
|
||||
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
||||
" float32 precision."
|
||||
),
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
@@ -793,10 +797,16 @@ def main(args):
|
||||
|
||||
# Load the tokenizers
|
||||
tokenizer_one = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
tokenizer_two = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer_2",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
# import correct text encoder classes
|
||||
@@ -810,10 +820,10 @@ def main(args):
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae_path = (
|
||||
args.pretrained_model_name_or_path
|
||||
@@ -824,9 +834,10 @@ def main(args):
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
if args.controlnet_model_name_or_path:
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
def test_custom_diffusion(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/custom_diffusion/train_custom_diffusion.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt <new1>
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 1.0e-05
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--modifier_token <new1>
|
||||
--no_safe_serialization
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "<new1>.bin")))
|
||||
|
||||
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/custom_diffusion/train_custom_diffusion.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir=docs/source/en/imgs
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt=<new1>
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--modifier_token=<new1>
|
||||
--dataloader_num_workers=0
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
--no_safe_serialization
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/custom_diffusion/train_custom_diffusion.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir=docs/source/en/imgs
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt=<new1>
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--modifier_token=<new1>
|
||||
--dataloader_num_workers=0
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
--no_safe_serialization
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/custom_diffusion/train_custom_diffusion.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir=docs/source/en/imgs
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt=<new1>
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--modifier_token=<new1>
|
||||
--dataloader_num_workers=0
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--no_safe_serialization
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -332,6 +332,12 @@ def parse_args(input_args=None):
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
@@ -740,6 +746,7 @@ def main(args):
|
||||
torch_dtype=torch_dtype,
|
||||
safety_checker=None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -801,11 +808,13 @@ def main(args):
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
# Adding a modifier token which is optimized ####
|
||||
@@ -1229,6 +1238,7 @@ def main(args):
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
tokenizer=tokenizer,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
@@ -1278,7 +1288,7 @@ def main(args):
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
|
||||
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
|
||||
@@ -44,6 +44,7 @@ write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Dog toy example
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Dog toy example
|
||||
|
||||
|
||||
@@ -4,3 +4,4 @@ transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft==0.7.0
|
||||
@@ -4,3 +4,4 @@ transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft==0.7.0
|
||||
@@ -0,0 +1,230 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from diffusers import DiffusionPipeline, UNet2DConditionModel
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBooth(ExamplesTestsAccelerate):
|
||||
def test_dreambooth(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
|
||||
|
||||
def test_dreambooth_if(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--pre_compute_text_embeddings
|
||||
--tokenizer_max_length=77
|
||||
--text_encoder_use_attention_mask
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
|
||||
|
||||
def test_dreambooth_checkpointing(self):
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 5, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
examples/dreambooth/train_dreambooth.py
|
||||
--pretrained_model_name_or_path {pretrained_model_name_or_path}
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt {instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 5
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
# check can run the original fully trained output pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(instance_prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
|
||||
# check can run an intermediate checkpoint
|
||||
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
|
||||
pipe(instance_prompt, num_inference_steps=2)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
|
||||
# Run training script for 7 total steps resuming from checkpoint 4
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/dreambooth/train_dreambooth.py
|
||||
--pretrained_model_name_or_path {pretrained_model_name_or_path}
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt {instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(instance_prompt, num_inference_steps=2)
|
||||
|
||||
# check old checkpoints do not exist
|
||||
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
|
||||
# check new checkpoints exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
|
||||
|
||||
def test_dreambooth_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir=docs/source/en/imgs
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt=prompt
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir=docs/source/en/imgs
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt=prompt
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/dreambooth/train_dreambooth.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir=docs/source/en/imgs
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt=prompt
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
@@ -0,0 +1,388 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
from diffusers import DiffusionPipeline # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRA(ExamplesTestsAccelerate):
|
||||
def test_dreambooth_lora(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"unet"` in their names.
|
||||
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_unet)
|
||||
|
||||
def test_dreambooth_lora_with_text_encoder(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--train_text_encoder
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# check `text_encoder` is present at all.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
keys = lora_state_dict.keys()
|
||||
is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
|
||||
self.assertTrue(is_text_encoder_present)
|
||||
|
||||
# the names of the keys of the state dict should either start with `unet`
|
||||
# or `text_encoder`.
|
||||
is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
|
||||
self.assertTrue(is_correct_naming)
|
||||
|
||||
def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir=docs/source/en/imgs
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt=prompt
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir=docs/source/en/imgs
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt=prompt
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir=docs/source/en/imgs
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt=prompt
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_if_model(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--pre_compute_text_embeddings
|
||||
--tokenizer_max_length=77
|
||||
--text_encoder_use_attention_mask
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"unet"` in their names.
|
||||
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_unet)
|
||||
|
||||
|
||||
class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
|
||||
def test_dreambooth_lora_sdxl(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"unet"` in their names.
|
||||
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_unet)
|
||||
|
||||
def test_dreambooth_lora_sdxl_with_text_encoder(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--train_text_encoder
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
|
||||
keys = lora_state_dict.keys()
|
||||
starts_with_unet = all(
|
||||
k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
|
||||
)
|
||||
self.assertTrue(starts_with_unet)
|
||||
|
||||
def test_dreambooth_lora_sdxl_custom_captions(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--dataset_name hf-internal-testing/dummy_image_text_data
|
||||
--caption_column text
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--dataset_name hf-internal-testing/dummy_image_text_data
|
||||
--caption_column text
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--train_text_encoder
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path {pipeline_path}
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe("a prompt", num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path {pipeline_path}
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--train_text_encoder
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe("a prompt", num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -139,6 +139,7 @@ def log_validation(
|
||||
text_encoder=text_encoder,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
**pipeline_args,
|
||||
)
|
||||
@@ -239,10 +240,13 @@ def parse_args(input_args=None):
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=(
|
||||
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
||||
" float32 precision."
|
||||
),
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
@@ -296,7 +300,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
default="dreambooth-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
@@ -859,6 +863,7 @@ def main(args):
|
||||
torch_dtype=torch_dtype,
|
||||
safety_checker=None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -912,18 +917,18 @@ def main(args):
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
if model_has_vae(args):
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
|
||||
)
|
||||
else:
|
||||
vae = None
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -1379,6 +1384,7 @@ def main(args):
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
**pipeline_args,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
@@ -460,7 +460,10 @@ def main():
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
dtype=weight_dtype,
|
||||
revision=args.revision,
|
||||
)
|
||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
|
||||
vae_arg,
|
||||
@@ -468,7 +471,10 @@ def main():
|
||||
**vae_kwargs,
|
||||
)
|
||||
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
dtype=weight_dtype,
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
# Optimization
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
import argparse
|
||||
import copy
|
||||
import gc
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -35,6 +34,8 @@ from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from huggingface_hub.utils import insecure_hashlib
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torch.utils.data import Dataset
|
||||
@@ -52,24 +53,50 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
)
|
||||
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import unet_lora_state_dict
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
@@ -150,6 +177,12 @@ def parse_args(input_args=None):
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
@@ -717,6 +750,7 @@ def main(args):
|
||||
torch_dtype=torch_dtype,
|
||||
safety_checker=None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -770,11 +804,11 @@ def main(args):
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
try:
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
|
||||
)
|
||||
except OSError:
|
||||
# IF does not have a VAE so let's just set it to None
|
||||
@@ -782,7 +816,7 @@ def main(args):
|
||||
vae = None
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
# We only train the additional adapter LoRA layers
|
||||
@@ -824,79 +858,19 @@ def main(args):
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
# It's important to realize here how many attention weights will be added and of which sizes
|
||||
# The sizes of the attention layers consist only of two different variables:
|
||||
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
|
||||
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
|
||||
)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
|
||||
# Let's first see how many attention processors we will have to set.
|
||||
# For Stable Diffusion, it should be equal to:
|
||||
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
|
||||
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
|
||||
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
|
||||
# => 32 layers
|
||||
|
||||
# Set correct lora layers
|
||||
unet_lora_parameters = []
|
||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
||||
# Parse the attention module.
|
||||
attn_module = unet
|
||||
for n in attn_processor_name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
|
||||
# Accumulate the LoRA params to optimize.
|
||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
||||
|
||||
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
|
||||
attn_module.add_k_proj.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.add_k_proj.in_features,
|
||||
out_features=attn_module.add_k_proj.out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
attn_module.add_v_proj.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.add_v_proj.in_features,
|
||||
out_features=attn_module.add_v_proj.out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
# The text encoder comes from 🤗 transformers, we will also attach adapters to it.
|
||||
if args.train_text_encoder:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
)
|
||||
text_encoder.add_adapter(text_lora_config)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
@@ -908,9 +882,9 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_lora_state_dict(model)
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -970,11 +944,10 @@ def main(args):
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
# Optimizer creation
|
||||
params_to_optimize = (
|
||||
itertools.chain(unet_lora_parameters, text_lora_parameters)
|
||||
if args.train_text_encoder
|
||||
else unet_lora_parameters
|
||||
)
|
||||
params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
if args.train_text_encoder:
|
||||
params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
@@ -1217,12 +1190,7 @@ def main(args):
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(unet_lora_parameters, text_lora_parameters)
|
||||
if args.train_text_encoder
|
||||
else unet_lora_parameters
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
@@ -1277,6 +1245,7 @@ def main(args):
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
@@ -1344,25 +1313,25 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = unet_lora_state_dict(unet)
|
||||
|
||||
if text_encoder is not None and args.train_text_encoder:
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
text_encoder = text_encoder.to(torch.float32)
|
||||
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
|
||||
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_state_dict = None
|
||||
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
|
||||
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
|
||||
@@ -34,6 +34,8 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from huggingface_hub.utils import insecure_hashlib
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torch.utils.data import Dataset
|
||||
@@ -50,50 +52,112 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import unet_lora_state_dict
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
|
||||
repo_id: str,
|
||||
images=None,
|
||||
base_model=str,
|
||||
train_text_encoder=False,
|
||||
instance_prompt=str,
|
||||
validation_prompt=str,
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = ""
|
||||
img_str = "widget:\n" if images else ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
img_str += f"""
|
||||
- text: '{validation_prompt if validation_prompt else ' ' }'
|
||||
output:
|
||||
url:
|
||||
"image_{i}.png"
|
||||
"""
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: openrail++
|
||||
base_model: {base_model}
|
||||
instance_prompt: {prompt}
|
||||
tags:
|
||||
- stable-diffusion-xl
|
||||
- stable-diffusion-xl-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
inference: true
|
||||
- template:sd-lora
|
||||
{img_str}
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# LoRA DreamBooth - {repo_id}
|
||||
|
||||
These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
|
||||
{img_str}
|
||||
model_card = f"""
|
||||
# SDXL LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
|
||||
## Model description
|
||||
|
||||
These are {repo_id} LoRA adaption weights for {base_model}.
|
||||
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
|
||||
|
||||
LoRA for the text encoder was enabled: {train_text_encoder}.
|
||||
|
||||
Special VAE used for training: {vae_path}.
|
||||
|
||||
## Trigger words
|
||||
|
||||
You should use {instance_prompt} to trigger the image generation.
|
||||
|
||||
## Download model
|
||||
|
||||
Weights for this model are available in Safetensors format.
|
||||
|
||||
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
||||
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
@@ -141,13 +205,59 @@ def parse_args(input_args=None):
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A folder containing the training data of instance images.",
|
||||
help=("A folder containing the training data. "),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--image_column",
|
||||
type=str,
|
||||
default="image",
|
||||
help="The column of the dataset containing the target image. By "
|
||||
"default, the standard Image Dataset maps out 'file_name' "
|
||||
"to 'image'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The column of the dataset containing the instance prompt for each image",
|
||||
)
|
||||
|
||||
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
|
||||
|
||||
parser.add_argument(
|
||||
"--class_data_dir",
|
||||
type=str,
|
||||
@@ -160,7 +270,7 @@ def parse_args(input_args=None):
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="The prompt with identifier specifying the instance",
|
||||
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_prompt",
|
||||
@@ -299,9 +409,16 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-4,
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--text_encoder_lr",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
help="Text encoder learning rate to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
@@ -317,6 +434,14 @@ def parse_args(input_args=None):
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--snr_gamma",
|
||||
type=float,
|
||||
default=None,
|
||||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
||||
"More details here: https://arxiv.org/abs/2303.09556.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
@@ -335,13 +460,59 @@ def parse_args(input_args=None):
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
"--optimizer",
|
||||
type=str,
|
||||
default="AdamW",
|
||||
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam",
|
||||
action="store_true",
|
||||
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prodigy_beta3",
|
||||
type=float,
|
||||
default=None,
|
||||
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
|
||||
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
||||
)
|
||||
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
default=1e-08,
|
||||
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prodigy_use_bias_correction",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prodigy_safeguard_warmup",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
|
||||
"Ignored if optimizer is adamW",
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
@@ -414,6 +585,12 @@ def parse_args(input_args=None):
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dataset_name is None and args.instance_data_dir is None:
|
||||
raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
|
||||
|
||||
if args.dataset_name is not None and args.instance_data_dir is not None:
|
||||
raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
@@ -442,20 +619,84 @@ class DreamBoothDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
instance_data_root,
|
||||
instance_prompt,
|
||||
class_prompt,
|
||||
class_data_root=None,
|
||||
class_num=None,
|
||||
size=1024,
|
||||
repeats=1,
|
||||
center_crop=False,
|
||||
):
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
|
||||
self.instance_data_root = Path(instance_data_root)
|
||||
if not self.instance_data_root.exists():
|
||||
raise ValueError("Instance images root doesn't exists.")
|
||||
self.instance_prompt = instance_prompt
|
||||
self.custom_instance_prompts = None
|
||||
self.class_prompt = class_prompt
|
||||
|
||||
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
||||
self.num_instance_images = len(self.instance_images_path)
|
||||
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
|
||||
# we load the training data using load_dataset
|
||||
if args.dataset_name is not None:
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"You are trying to load your data using the datasets library. If you wish to train using custom "
|
||||
"captions please install the datasets library: `pip install datasets`. If you wish to load a "
|
||||
"local folder containing images only, specify --instance_data_dir instead."
|
||||
)
|
||||
# Downloading and loading a dataset from the hub.
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
# Preprocessing the datasets.
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
# 6. Get the column names for input/target.
|
||||
if args.image_column is None:
|
||||
image_column = column_names[0]
|
||||
logger.info(f"image column defaulting to {image_column}")
|
||||
else:
|
||||
image_column = args.image_column
|
||||
if image_column not in column_names:
|
||||
raise ValueError(
|
||||
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
)
|
||||
instance_images = dataset["train"][image_column]
|
||||
|
||||
if args.caption_column is None:
|
||||
logger.info(
|
||||
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
|
||||
"contains captions/prompts for the images, make sure to specify the "
|
||||
"column as --caption_column"
|
||||
)
|
||||
self.custom_instance_prompts = None
|
||||
else:
|
||||
if args.caption_column not in column_names:
|
||||
raise ValueError(
|
||||
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
)
|
||||
custom_instance_prompts = dataset["train"][args.caption_column]
|
||||
# create final list of captions according to --repeats
|
||||
self.custom_instance_prompts = []
|
||||
for caption in custom_instance_prompts:
|
||||
self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
|
||||
else:
|
||||
self.instance_data_root = Path(instance_data_root)
|
||||
if not self.instance_data_root.exists():
|
||||
raise ValueError("Instance images root doesn't exists.")
|
||||
|
||||
instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
|
||||
self.custom_instance_prompts = None
|
||||
|
||||
self.instance_images = []
|
||||
for img in instance_images:
|
||||
self.instance_images.extend(itertools.repeat(img, repeats))
|
||||
self.num_instance_images = len(self.instance_images)
|
||||
self._length = self.num_instance_images
|
||||
|
||||
if class_data_root is not None:
|
||||
@@ -484,13 +725,23 @@ class DreamBoothDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
|
||||
instance_image = self.instance_images[index % self.num_instance_images]
|
||||
instance_image = exif_transpose(instance_image)
|
||||
|
||||
if not instance_image.mode == "RGB":
|
||||
instance_image = instance_image.convert("RGB")
|
||||
example["instance_images"] = self.image_transforms(instance_image)
|
||||
|
||||
if self.custom_instance_prompts:
|
||||
caption = self.custom_instance_prompts[index % self.num_instance_images]
|
||||
if caption:
|
||||
example["instance_prompt"] = caption
|
||||
else:
|
||||
example["instance_prompt"] = self.instance_prompt
|
||||
|
||||
else: # costum prompts were provided, but length does not match size of image dataset
|
||||
example["instance_prompt"] = self.instance_prompt
|
||||
|
||||
if self.class_data_root:
|
||||
class_image = Image.open(self.class_images_path[index % self.num_class_images])
|
||||
class_image = exif_transpose(class_image)
|
||||
@@ -498,22 +749,25 @@ class DreamBoothDataset(Dataset):
|
||||
if not class_image.mode == "RGB":
|
||||
class_image = class_image.convert("RGB")
|
||||
example["class_images"] = self.image_transforms(class_image)
|
||||
example["class_prompt"] = self.class_prompt
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
|
||||
# Concat class and instance examples for prior preservation.
|
||||
# We do this to avoid doing two forward passes.
|
||||
if with_prior_preservation:
|
||||
pixel_values += [example["class_images"] for example in examples]
|
||||
prompts += [example["class_prompt"] for example in examples]
|
||||
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
batch = {"pixel_values": pixel_values}
|
||||
batch = {"pixel_values": pixel_values, "prompts": prompts}
|
||||
return batch
|
||||
|
||||
|
||||
@@ -630,6 +884,7 @@ def main(args):
|
||||
args.pretrained_model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -668,10 +923,16 @@ def main(args):
|
||||
|
||||
# Load the tokenizers
|
||||
tokenizer_one = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
tokenizer_two = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer_2",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
# import correct text encoder classes
|
||||
@@ -685,10 +946,10 @@ def main(args):
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae_path = (
|
||||
args.pretrained_model_name_or_path
|
||||
@@ -696,10 +957,13 @@ def main(args):
|
||||
else args.pretrained_vae_model_name_or_path
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
# We only train the additional adapter LoRA layers
|
||||
@@ -732,7 +996,8 @@ def main(args):
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
|
||||
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
@@ -745,54 +1010,19 @@ def main(args):
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
# Set correct lora layers
|
||||
unet_lora_parameters = []
|
||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
||||
# Parse the attention module.
|
||||
attn_module = unet
|
||||
for n in attn_processor_name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
|
||||
# Accumulate the LoRA params to optimize.
|
||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
|
||||
)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
if args.train_text_encoder:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
|
||||
text_encoder_one, dtype=torch.float32, rank=args.rank
|
||||
)
|
||||
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
|
||||
text_encoder_two, dtype=torch.float32, rank=args.rank
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
@@ -805,11 +1035,11 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_lora_state_dict(model)
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -866,35 +1096,109 @@ def main(args):
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
||||
)
|
||||
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
if args.train_text_encoder:
|
||||
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
||||
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
|
||||
|
||||
# Optimization parameters
|
||||
unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
|
||||
if args.train_text_encoder:
|
||||
# different learning rate for text encoder and unet
|
||||
text_lora_parameters_one_with_lr = {
|
||||
"params": text_lora_parameters_one,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
text_lora_parameters_two_with_lr = {
|
||||
"params": text_lora_parameters_two,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [
|
||||
unet_lora_parameters_with_lr,
|
||||
text_lora_parameters_one_with_lr,
|
||||
text_lora_parameters_two_with_lr,
|
||||
]
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
params_to_optimize = [unet_lora_parameters_with_lr]
|
||||
|
||||
# Optimizer creation
|
||||
params_to_optimize = (
|
||||
itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
|
||||
if args.train_text_encoder
|
||||
else unet_lora_parameters
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
logger.warn(
|
||||
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
|
||||
"Defaulting to adamW"
|
||||
)
|
||||
args.optimizer = "adamw"
|
||||
|
||||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
|
||||
logger.warn(
|
||||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
)
|
||||
|
||||
if args.optimizer.lower() == "adamw":
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
if args.optimizer.lower() == "prodigy":
|
||||
try:
|
||||
import prodigyopt
|
||||
except ImportError:
|
||||
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
|
||||
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
decouple=args.prodigy_decouple,
|
||||
use_bias_correction=args.prodigy_use_bias_correction,
|
||||
safeguard_warmup=args.prodigy_safeguard_warmup,
|
||||
)
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
train_dataset = DreamBoothDataset(
|
||||
instance_data_root=args.instance_data_dir,
|
||||
instance_prompt=args.instance_prompt,
|
||||
class_prompt=args.class_prompt,
|
||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||
class_num=args.num_class_images,
|
||||
size=args.resolution,
|
||||
repeats=args.repeats,
|
||||
center_crop=args.center_crop,
|
||||
)
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Computes additional embeddings/ids required by the SDXL UNet.
|
||||
# regular text emebddings (when `train_text_encoder` is not True)
|
||||
# regular text embeddings (when `train_text_encoder` is not True)
|
||||
# pooled text embeddings
|
||||
# time ids
|
||||
|
||||
@@ -921,7 +1225,11 @@ def main(args):
|
||||
|
||||
# Handle instance prompt.
|
||||
instance_time_ids = compute_time_ids()
|
||||
if not args.train_text_encoder:
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
|
||||
# the redundant encoding.
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.instance_prompt, text_encoders, tokenizers
|
||||
)
|
||||
@@ -934,49 +1242,36 @@ def main(args):
|
||||
args.class_prompt, text_encoders, tokenizers
|
||||
)
|
||||
|
||||
# Clear the memory here.
|
||||
if not args.train_text_encoder:
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
del tokenizers, text_encoders
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Pack the statically computed variables appropriately. This is so that we don't
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
add_time_ids = instance_time_ids
|
||||
if args.with_prior_preservation:
|
||||
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
|
||||
|
||||
if not args.train_text_encoder:
|
||||
prompt_embeds = instance_prompt_hidden_states
|
||||
unet_add_text_embeds = instance_pooled_prompt_embeds
|
||||
if args.with_prior_preservation:
|
||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
||||
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
|
||||
else:
|
||||
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
||||
tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
|
||||
if args.with_prior_preservation:
|
||||
class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
|
||||
class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
|
||||
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
train_dataset = DreamBoothDataset(
|
||||
instance_data_root=args.instance_data_dir,
|
||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||
class_num=args.num_class_images,
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
)
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
if not args.train_text_encoder:
|
||||
prompt_embeds = instance_prompt_hidden_states
|
||||
unet_add_text_embeds = instance_pooled_prompt_embeds
|
||||
if args.with_prior_preservation:
|
||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
||||
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
|
||||
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
||||
# batch prompts on all training steps
|
||||
else:
|
||||
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
||||
tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
|
||||
if args.with_prior_preservation:
|
||||
class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
|
||||
class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
|
||||
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
@@ -1079,6 +1374,17 @@ def main(args):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
||||
prompts = batch["prompts"]
|
||||
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if not args.train_text_encoder:
|
||||
prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
|
||||
prompts, text_encoders, tokenizers
|
||||
)
|
||||
else:
|
||||
tokens_one = tokenize_prompt(tokenizer_one, prompts)
|
||||
tokens_two = tokenize_prompt(tokenizer_two, prompts)
|
||||
|
||||
# Convert images to latent space
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
@@ -1099,16 +1405,21 @@ def main(args):
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
|
||||
# Calculate the elements to repeat depending on the use of prior-preservation.
|
||||
elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
|
||||
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
else:
|
||||
elems_to_repeat_text_embeds = 1
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
|
||||
# Predict the noise residual
|
||||
if not args.train_text_encoder:
|
||||
unet_added_conditions = {
|
||||
"time_ids": add_time_ids.repeat(elems_to_repeat, 1),
|
||||
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
|
||||
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
|
||||
}
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
model_pred = unet(
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
@@ -1116,15 +1427,17 @@ def main(args):
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
).sample
|
||||
else:
|
||||
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
|
||||
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=None,
|
||||
prompt=None,
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
)
|
||||
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
|
||||
unet_added_conditions.update(
|
||||
{"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
|
||||
)
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
|
||||
).sample
|
||||
@@ -1142,16 +1455,34 @@ def main(args):
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
|
||||
# Compute instance loss
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
|
||||
if args.snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
base_weight = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = loss + args.prior_loss_weight * prior_loss
|
||||
else:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
@@ -1212,10 +1543,16 @@ def main(args):
|
||||
# create pipeline
|
||||
if not args.train_text_encoder:
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder_2",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
@@ -1224,6 +1561,7 @@ def main(args):
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
@@ -1277,13 +1615,13 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = unet_lora_state_dict(unet)
|
||||
unet_lora_layers = get_peft_model_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
@@ -1301,10 +1639,15 @@ def main(args):
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
@@ -1353,7 +1696,8 @@ def main(args):
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
prompt=args.instance_prompt,
|
||||
instance_prompt=args.instance_prompt,
|
||||
validation_prompt=args.validation_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
vae_path=args.pretrained_vae_model_name_or_path,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/instruct_pix2pix/train_instruct_pix2pix.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--dataset_name=hf-internal-testing/instructpix2pix-10-samples
|
||||
--resolution=64
|
||||
--random_flip
|
||||
--train_batch_size=1
|
||||
--max_train_steps=7
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--output_dir {tmpdir}
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/instruct_pix2pix/train_instruct_pix2pix.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--dataset_name=hf-internal-testing/instructpix2pix-10-samples
|
||||
--resolution=64
|
||||
--random_flip
|
||||
--train_batch_size=1
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
--output_dir {tmpdir}
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/instruct_pix2pix/train_instruct_pix2pix.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--dataset_name=hf-internal-testing/instructpix2pix-10-samples
|
||||
--resolution=64
|
||||
--random_flip
|
||||
--train_batch_size=1
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--output_dir {tmpdir}
|
||||
--seed=0
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -78,6 +78,12 @@ def parse_args():
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
@@ -435,9 +441,11 @@ def main():
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
|
||||
)
|
||||
@@ -915,6 +923,7 @@ def main():
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
@@ -966,6 +975,7 @@ def main():
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
unet=unet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -118,6 +118,12 @@ def parse_args():
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
@@ -484,9 +490,10 @@ def main():
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
|
||||
@@ -695,10 +702,16 @@ def main():
|
||||
|
||||
# Load scheduler, tokenizer and models.
|
||||
tokenizer_1 = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer_2",
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
||||
text_encoder_cls_2 = import_model_class_from_model_name_or_path(
|
||||
@@ -708,10 +721,10 @@ def main():
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder_1 = text_encoder_cls_1.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
text_encoder_2 = text_encoder_cls_2.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
# We ALWAYS pre-compute the additional condition embeddings needed for SDXL
|
||||
@@ -1109,6 +1122,7 @@ def main():
|
||||
tokenizer_2=tokenizer_2,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
@@ -1176,6 +1190,7 @@ def main():
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.24.0.dev0")
|
||||
check_min_version("0.25.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Research projects
|
||||
|
||||
This folder contains various research projects using 🧨 Diffusers.
|
||||
They are not really maintained by the core maintainers of this library and often require a specific version of Diffusers that is indicated in the requirements file of each folder.
|
||||
This folder contains various research projects using 🧨 Diffusers.
|
||||
They are not really maintained by the core maintainers of this library and often require a specific version of Diffusers that is indicated in the requirements file of each folder.
|
||||
Updating them to the most recent version of the library will require some work.
|
||||
|
||||
To use any of them, just run the command
|
||||
|
||||
@@ -420,7 +420,7 @@ def import_model_class_from_model_name_or_path(
|
||||
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
||||
):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
|
||||
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
||||
)
|
||||
model_class = text_encoder_config.architectures[0]
|
||||
|
||||
@@ -975,7 +975,7 @@ def main(args):
|
||||
revision=args.revision,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
)
|
||||
|
||||
if args.controlnet_model_name_or_path:
|
||||
|
||||
+5
-5
@@ -1,6 +1,6 @@
|
||||
## [Deprecated] Multi Token Textual Inversion
|
||||
|
||||
**IMPORTART: This research project is deprecated. Multi Token Textual Inversion is now supported natively in [the officail textual inversion example](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#running-locally-with-pytorch).**
|
||||
**IMPORTART: This research project is deprecated. Multi Token Textual Inversion is now supported natively in [the official textual inversion example](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#running-locally-with-pytorch).**
|
||||
|
||||
The author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author for issue and PRs as well as @patrickvonplaten.
|
||||
|
||||
@@ -17,9 +17,9 @@ Feel free to add these options to your training! In practice num_vec_per_token a
|
||||
[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
|
||||
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
## Running on Colab
|
||||
## Running on Colab
|
||||
|
||||
Colab for training
|
||||
Colab for training
|
||||
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
|
||||
Colab for inference
|
||||
@@ -53,7 +53,7 @@ accelerate config
|
||||
|
||||
### Cat toy example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
@@ -63,7 +63,7 @@ Run the following command to authenticate your token
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
If you have already cloned the repo, then you won't need to go through these steps.
|
||||
If you have already cloned the repo, then you won't need to go through these steps.
|
||||
|
||||
<br>
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
**This research project is not actively maintained by the diffusers team. For any questions or comments, please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.**
|
||||
|
||||
This aims to provide diffusers examples with ONNXRuntime optimizations for training/fine-tuning unconditional image generation, text to image, and textual inversion. Please see individual directories for more details on how to run each task using ONNXRuntime.
|
||||
This aims to provide diffusers examples with ONNXRuntime optimizations for training/fine-tuning unconditional image generation, text to image, and textual inversion. Please see individual directories for more details on how to run each task using ONNXRuntime.
|
||||
|
||||
@@ -34,7 +34,7 @@ accelerate config
|
||||
|
||||
### Pokemon example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
@@ -68,7 +68,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
|
||||
--learning_rate=1e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--output_dir="sd-pokemon-model"
|
||||
--output_dir="sd-pokemon-model"
|
||||
```
|
||||
|
||||
Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.
|
||||
@@ -3,9 +3,9 @@
|
||||
[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
|
||||
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
## Running on Colab
|
||||
## Running on Colab
|
||||
|
||||
Colab for training
|
||||
Colab for training
|
||||
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
|
||||
Colab for inference
|
||||
@@ -39,7 +39,7 @@ accelerate config
|
||||
|
||||
### Cat toy example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
@@ -49,7 +49,7 @@ Run the following command to authenticate your token
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
If you have already cloned the repo, then you won't need to go through these steps.
|
||||
If you have already cloned the repo, then you won't need to go through these steps.
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user