Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cb69798b3d | |||
| 0229976ab5 | |||
| 8f1b207ffd | |||
| ccdd96ca52 | |||
| 4c723d8ec3 | |||
| bec2d8eaea | |||
| a0a51eb098 | |||
| a5a0ccf86a | |||
| dd07b19e27 | |||
| 57636ad4f4 | |||
| cefc2cf82d | |||
| b3e56e71fb | |||
| 5b5fa49a89 | |||
| decfa3c9e1 | |||
| 48305755bf | |||
| 7853bfbed7 | |||
| 23ebbb4bc8 | |||
| 1b456bd5d5 | |||
| af769881d3 | |||
| 99308efb55 | |||
| 5015ce4fc7 | |||
| 5ed984cc47 |
@@ -7,7 +7,7 @@ on:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
|
||||
@@ -42,18 +42,39 @@ jobs:
|
||||
CHANGED_FILES: ${{ steps.file_changes.outputs.all }}
|
||||
run: |
|
||||
echo "$CHANGED_FILES"
|
||||
for FILE in $CHANGED_FILES; do
|
||||
ALLOWED_IMAGES=(
|
||||
diffusers-pytorch-cpu
|
||||
diffusers-pytorch-cuda
|
||||
diffusers-pytorch-xformers-cuda
|
||||
diffusers-pytorch-minimum-cuda
|
||||
diffusers-doc-builder
|
||||
)
|
||||
|
||||
declare -A IMAGES_TO_BUILD=()
|
||||
|
||||
for FILE in $CHANGED_FILES; do
|
||||
# skip anything that isn't still on disk
|
||||
if [[ ! -f "$FILE" ]]; then
|
||||
if [[ ! -e "$FILE" ]]; then
|
||||
echo "Skipping removed file $FILE"
|
||||
continue
|
||||
fi
|
||||
if [[ "$FILE" == docker/*Dockerfile ]]; then
|
||||
DOCKER_PATH="${FILE%/Dockerfile}"
|
||||
DOCKER_TAG=$(basename "$DOCKER_PATH")
|
||||
echo "Building Docker image for $DOCKER_TAG"
|
||||
docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
|
||||
fi
|
||||
|
||||
for IMAGE in "${ALLOWED_IMAGES[@]}"; do
|
||||
if [[ "$FILE" == docker/${IMAGE}/* ]]; then
|
||||
IMAGES_TO_BUILD["$IMAGE"]=1
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if [[ ${#IMAGES_TO_BUILD[@]} -eq 0 ]]; then
|
||||
echo "No relevant Docker changes detected."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
for IMAGE in "${!IMAGES_TO_BUILD[@]}"; do
|
||||
DOCKER_PATH="docker/${IMAGE}"
|
||||
echo "Building Docker image for $IMAGE"
|
||||
docker build -t "$IMAGE" "$DOCKER_PATH"
|
||||
done
|
||||
if: steps.file_changes.outputs.all != ''
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
PYTEST_TIMEOUT: 600
|
||||
|
||||
@@ -26,7 +26,7 @@ concurrency:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
@@ -22,7 +22,7 @@ concurrency:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
@@ -24,7 +24,7 @@ env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
PYTEST_TIMEOUT: 600
|
||||
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
PYTEST_TIMEOUT: 600
|
||||
PIPELINE_USAGE_CUTOFF: 50000
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ env:
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: no
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ env:
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: no
|
||||
|
||||
|
||||
@@ -171,7 +171,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td>Text-guided Image Inpainting</td>
|
||||
<td><a href="https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint">Stable Diffusion Inpainting</a></td>
|
||||
<td><a href="https://huggingface.co/runwayml/stable-diffusion-inpainting"> runwayml/stable-diffusion-inpainting </a></td>
|
||||
<td><a href="https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting"> stable-diffusion-v1-5/stable-diffusion-inpainting </a></td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td>Image Variation</td>
|
||||
|
||||
@@ -33,7 +33,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
|
||||
RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
hf_transfer \
|
||||
hf_xet \
|
||||
setuptools==69.5.1 \
|
||||
bitsandbytes \
|
||||
torchao \
|
||||
|
||||
@@ -44,6 +44,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
hf_transfer
|
||||
hf_xet
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -38,13 +38,12 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
hf_transfer \
|
||||
hf_xet \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
hf_transfer
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -31,7 +31,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
|
||||
RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
hf_transfer
|
||||
hf_xet
|
||||
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
|
||||
|
||||
@@ -44,6 +44,6 @@ RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
pytorch-lightning \
|
||||
hf_transfer
|
||||
hf_xet
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -47,6 +47,6 @@ RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
pytorch-lightning \
|
||||
hf_transfer
|
||||
hf_xet
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -44,7 +44,7 @@ RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
pytorch-lightning \
|
||||
hf_transfer \
|
||||
hf_xet \
|
||||
xformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
+260
-263
@@ -1,5 +1,4 @@
|
||||
- title: Get started
|
||||
sections:
|
||||
- sections:
|
||||
- local: index
|
||||
title: Diffusers
|
||||
- local: installation
|
||||
@@ -8,9 +7,8 @@
|
||||
title: Quickstart
|
||||
- local: stable_diffusion
|
||||
title: Basic performance
|
||||
|
||||
- title: Pipelines
|
||||
isExpanded: false
|
||||
title: Get started
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: using-diffusers/loading
|
||||
title: DiffusionPipeline
|
||||
@@ -28,9 +26,8 @@
|
||||
title: Model formats
|
||||
- local: using-diffusers/push_to_hub
|
||||
title: Sharing pipelines and models
|
||||
|
||||
- title: Adapters
|
||||
isExpanded: false
|
||||
title: Pipelines
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: tutorials/using_peft_for_inference
|
||||
title: LoRA
|
||||
@@ -44,9 +41,8 @@
|
||||
title: DreamBooth
|
||||
- local: using-diffusers/textual_inversion_inference
|
||||
title: Textual inversion
|
||||
|
||||
- title: Inference
|
||||
isExpanded: false
|
||||
title: Adapters
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: using-diffusers/weighted_prompts
|
||||
title: Prompting
|
||||
@@ -56,9 +52,8 @@
|
||||
title: Batch inference
|
||||
- local: training/distributed_inference
|
||||
title: Distributed inference
|
||||
|
||||
- title: Inference optimization
|
||||
isExpanded: false
|
||||
title: Inference
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: optimization/fp16
|
||||
title: Accelerate inference
|
||||
@@ -70,8 +65,7 @@
|
||||
title: Reduce memory usage
|
||||
- local: optimization/speed-memory-optims
|
||||
title: Compiling and offloading quantized models
|
||||
- title: Community optimizations
|
||||
sections:
|
||||
- sections:
|
||||
- local: optimization/pruna
|
||||
title: Pruna
|
||||
- local: optimization/xformers
|
||||
@@ -90,9 +84,9 @@
|
||||
title: ParaAttention
|
||||
- local: using-diffusers/image_quality
|
||||
title: FreeU
|
||||
|
||||
- title: Hybrid Inference
|
||||
isExpanded: false
|
||||
title: Community optimizations
|
||||
title: Inference optimization
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: hybrid_inference/overview
|
||||
title: Overview
|
||||
@@ -102,9 +96,8 @@
|
||||
title: VAE Encode
|
||||
- local: hybrid_inference/api_reference
|
||||
title: API Reference
|
||||
|
||||
- title: Modular Diffusers
|
||||
isExpanded: false
|
||||
title: Hybrid Inference
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: modular_diffusers/overview
|
||||
title: Overview
|
||||
@@ -126,9 +119,8 @@
|
||||
title: ComponentsManager
|
||||
- local: modular_diffusers/guiders
|
||||
title: Guiders
|
||||
|
||||
- title: Training
|
||||
isExpanded: false
|
||||
title: Modular Diffusers
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: training/overview
|
||||
title: Overview
|
||||
@@ -138,8 +130,7 @@
|
||||
title: Adapt a model to a new task
|
||||
- local: tutorials/basic_training
|
||||
title: Train a diffusion model
|
||||
- title: Models
|
||||
sections:
|
||||
- sections:
|
||||
- local: training/unconditional_training
|
||||
title: Unconditional image generation
|
||||
- local: training/text2image
|
||||
@@ -158,8 +149,8 @@
|
||||
title: InstructPix2Pix
|
||||
- local: training/cogvideox
|
||||
title: CogVideoX
|
||||
- title: Methods
|
||||
sections:
|
||||
title: Models
|
||||
- sections:
|
||||
- local: training/text_inversion
|
||||
title: Textual Inversion
|
||||
- local: training/dreambooth
|
||||
@@ -172,9 +163,9 @@
|
||||
title: Latent Consistency Distillation
|
||||
- local: training/ddpo
|
||||
title: Reinforcement learning training with DDPO
|
||||
|
||||
- title: Quantization
|
||||
isExpanded: false
|
||||
title: Methods
|
||||
title: Training
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: quantization/overview
|
||||
title: Getting started
|
||||
@@ -188,9 +179,8 @@
|
||||
title: quanto
|
||||
- local: quantization/modelopt
|
||||
title: NVIDIA ModelOpt
|
||||
|
||||
- title: Model accelerators and hardware
|
||||
isExpanded: false
|
||||
title: Quantization
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: optimization/onnx
|
||||
title: ONNX
|
||||
@@ -204,9 +194,8 @@
|
||||
title: Intel Gaudi
|
||||
- local: optimization/neuron
|
||||
title: AWS Neuron
|
||||
|
||||
- title: Specific pipeline examples
|
||||
isExpanded: false
|
||||
title: Model accelerators and hardware
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: using-diffusers/consisid
|
||||
title: ConsisID
|
||||
@@ -232,12 +221,10 @@
|
||||
title: Stable Video Diffusion
|
||||
- local: using-diffusers/marigold_usage
|
||||
title: Marigold Computer Vision
|
||||
|
||||
- title: Resources
|
||||
isExpanded: false
|
||||
title: Specific pipeline examples
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- title: Task recipes
|
||||
sections:
|
||||
- sections:
|
||||
- local: using-diffusers/unconditional_image_generation
|
||||
title: Unconditional image generation
|
||||
- local: using-diffusers/conditional_image_generation
|
||||
@@ -252,6 +239,7 @@
|
||||
title: Video generation
|
||||
- local: using-diffusers/depth2img
|
||||
title: Depth-to-image
|
||||
title: Task recipes
|
||||
- local: using-diffusers/write_own_pipeline
|
||||
title: Understanding pipelines, models and schedulers
|
||||
- local: community_projects
|
||||
@@ -266,12 +254,10 @@
|
||||
title: Diffusers' Ethical Guidelines
|
||||
- local: conceptual/evaluation
|
||||
title: Evaluating Diffusion Models
|
||||
|
||||
- title: API
|
||||
isExpanded: false
|
||||
title: Resources
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- title: Main Classes
|
||||
sections:
|
||||
- sections:
|
||||
- local: api/configuration
|
||||
title: Configuration
|
||||
- local: api/logging
|
||||
@@ -282,8 +268,8 @@
|
||||
title: Quantization
|
||||
- local: api/parallel
|
||||
title: Parallel inference
|
||||
- title: Modular
|
||||
sections:
|
||||
title: Main Classes
|
||||
- sections:
|
||||
- local: api/modular_diffusers/pipeline
|
||||
title: Pipeline
|
||||
- local: api/modular_diffusers/pipeline_blocks
|
||||
@@ -294,8 +280,8 @@
|
||||
title: Components and configs
|
||||
- local: api/modular_diffusers/guiders
|
||||
title: Guiders
|
||||
- title: Loaders
|
||||
sections:
|
||||
title: Modular
|
||||
- sections:
|
||||
- local: api/loaders/ip_adapter
|
||||
title: IP-Adapter
|
||||
- local: api/loaders/lora
|
||||
@@ -310,14 +296,13 @@
|
||||
title: SD3Transformer2D
|
||||
- local: api/loaders/peft
|
||||
title: PEFT
|
||||
- title: Models
|
||||
sections:
|
||||
title: Loaders
|
||||
- sections:
|
||||
- local: api/models/overview
|
||||
title: Overview
|
||||
- local: api/models/auto_model
|
||||
title: AutoModel
|
||||
- title: ControlNets
|
||||
sections:
|
||||
- sections:
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
- local: api/models/controlnet_union
|
||||
@@ -332,8 +317,8 @@
|
||||
title: SD3ControlNetModel
|
||||
- local: api/models/controlnet_sparsectrl
|
||||
title: SparseControlNetModel
|
||||
- title: Transformers
|
||||
sections:
|
||||
title: ControlNets
|
||||
- sections:
|
||||
- local: api/models/allegro_transformer3d
|
||||
title: AllegroTransformer3DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
@@ -396,8 +381,8 @@
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/wan_transformer_3d
|
||||
title: WanTransformer3DModel
|
||||
- title: UNets
|
||||
sections:
|
||||
title: Transformers
|
||||
- sections:
|
||||
- local: api/models/stable_cascade_unet
|
||||
title: StableCascadeUNet
|
||||
- local: api/models/unet
|
||||
@@ -412,8 +397,8 @@
|
||||
title: UNetMotionModel
|
||||
- local: api/models/uvit2d
|
||||
title: UViT2DModel
|
||||
- title: VAEs
|
||||
sections:
|
||||
title: UNets
|
||||
- sections:
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_dc
|
||||
@@ -446,210 +431,220 @@
|
||||
title: Tiny AutoEncoder
|
||||
- local: api/models/vq
|
||||
title: VQModel
|
||||
- title: Pipelines
|
||||
sections:
|
||||
title: VAEs
|
||||
title: Models
|
||||
- sections:
|
||||
- local: api/pipelines/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
- local: api/pipelines/amused
|
||||
title: aMUSEd
|
||||
- local: api/pipelines/animatediff
|
||||
title: AnimateDiff
|
||||
- local: api/pipelines/attend_and_excite
|
||||
title: Attend-and-Excite
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/aura_flow
|
||||
title: AuraFlow
|
||||
- sections:
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: Dance Diffusion
|
||||
- local: api/pipelines/musicldm
|
||||
title: MusicLDM
|
||||
- local: api/pipelines/stable_audio
|
||||
title: Stable Audio
|
||||
title: Audio
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/blip_diffusion
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/bria_3_2
|
||||
title: Bria 3.2
|
||||
- local: api/pipelines/chroma
|
||||
title: Chroma
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/cogview3
|
||||
title: CogView3
|
||||
- local: api/pipelines/cogview4
|
||||
title: CogView4
|
||||
- local: api/pipelines/consisid
|
||||
title: ConsisID
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
title: ControlNet
|
||||
- local: api/pipelines/controlnet_flux
|
||||
title: ControlNet with Flux.1
|
||||
- local: api/pipelines/controlnet_hunyuandit
|
||||
title: ControlNet with Hunyuan-DiT
|
||||
- local: api/pipelines/controlnet_sd3
|
||||
title: ControlNet with Stable Diffusion 3
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
title: ControlNet with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_sana
|
||||
title: ControlNet-Sana
|
||||
- local: api/pipelines/controlnetxs
|
||||
title: ControlNet-XS
|
||||
- local: api/pipelines/controlnetxs_sdxl
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: Dance Diffusion
|
||||
- local: api/pipelines/ddim
|
||||
title: DDIM
|
||||
- local: api/pipelines/ddpm
|
||||
title: DDPM
|
||||
- local: api/pipelines/deepfloyd_if
|
||||
title: DeepFloyd IF
|
||||
- local: api/pipelines/diffedit
|
||||
title: DiffEdit
|
||||
- local: api/pipelines/dit
|
||||
title: DiT
|
||||
- local: api/pipelines/easyanimate
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/hidream
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/pix2pix
|
||||
title: InstructPix2Pix
|
||||
- local: api/pipelines/kandinsky
|
||||
title: Kandinsky 2.1
|
||||
- local: api/pipelines/kandinsky_v22
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/kolors
|
||||
title: Kolors
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
title: Latent Consistency Models
|
||||
- local: api/pipelines/latent_diffusion
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTXVideo
|
||||
- local: api/pipelines/lumina2
|
||||
title: Lumina 2.0
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
title: Marigold
|
||||
- local: api/pipelines/mochi
|
||||
title: Mochi
|
||||
- local: api/pipelines/panorama
|
||||
title: MultiDiffusion
|
||||
- local: api/pipelines/musicldm
|
||||
title: MusicLDM
|
||||
- local: api/pipelines/omnigen
|
||||
title: OmniGen
|
||||
- local: api/pipelines/pag
|
||||
title: PAG
|
||||
- local: api/pipelines/paint_by_example
|
||||
title: Paint by Example
|
||||
- local: api/pipelines/pia
|
||||
title: Personalized Image Animator (PIA)
|
||||
- local: api/pipelines/pixart
|
||||
title: PixArt-α
|
||||
- local: api/pipelines/pixart_sigma
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/qwenimage
|
||||
title: QwenImage
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
title: Semantic Guidance
|
||||
- local: api/pipelines/shap_e
|
||||
title: Shap-E
|
||||
- local: api/pipelines/skyreels_v2
|
||||
title: SkyReels-V2
|
||||
- local: api/pipelines/stable_audio
|
||||
title: Stable Audio
|
||||
- local: api/pipelines/stable_cascade
|
||||
title: Stable Cascade
|
||||
- title: Stable Diffusion
|
||||
sections:
|
||||
- local: api/pipelines/stable_diffusion/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/stable_diffusion/depth2img
|
||||
title: Depth-to-image
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
title: GLIGEN (Grounded Language-to-Image Generation)
|
||||
- local: api/pipelines/stable_diffusion/image_variation
|
||||
title: Image variation
|
||||
- local: api/pipelines/stable_diffusion/img2img
|
||||
title: Image-to-image
|
||||
- sections:
|
||||
- local: api/pipelines/amused
|
||||
title: aMUSEd
|
||||
- local: api/pipelines/animatediff
|
||||
title: AnimateDiff
|
||||
- local: api/pipelines/attend_and_excite
|
||||
title: Attend-and-Excite
|
||||
- local: api/pipelines/aura_flow
|
||||
title: AuraFlow
|
||||
- local: api/pipelines/blip_diffusion
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/bria_3_2
|
||||
title: Bria 3.2
|
||||
- local: api/pipelines/chroma
|
||||
title: Chroma
|
||||
- local: api/pipelines/cogview3
|
||||
title: CogView3
|
||||
- local: api/pipelines/cogview4
|
||||
title: CogView4
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
title: ControlNet
|
||||
- local: api/pipelines/controlnet_flux
|
||||
title: ControlNet with Flux.1
|
||||
- local: api/pipelines/controlnet_hunyuandit
|
||||
title: ControlNet with Hunyuan-DiT
|
||||
- local: api/pipelines/controlnet_sd3
|
||||
title: ControlNet with Stable Diffusion 3
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
title: ControlNet with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_sana
|
||||
title: ControlNet-Sana
|
||||
- local: api/pipelines/controlnetxs
|
||||
title: ControlNet-XS
|
||||
- local: api/pipelines/controlnetxs_sdxl
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/ddim
|
||||
title: DDIM
|
||||
- local: api/pipelines/ddpm
|
||||
title: DDPM
|
||||
- local: api/pipelines/deepfloyd_if
|
||||
title: DeepFloyd IF
|
||||
- local: api/pipelines/diffedit
|
||||
title: DiffEdit
|
||||
- local: api/pipelines/dit
|
||||
title: DiT
|
||||
- local: api/pipelines/easyanimate
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hidream
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/pix2pix
|
||||
title: InstructPix2Pix
|
||||
- local: api/pipelines/kandinsky
|
||||
title: Kandinsky 2.1
|
||||
- local: api/pipelines/kandinsky_v22
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/kolors
|
||||
title: Kolors
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
title: Latent Consistency Models
|
||||
- local: api/pipelines/latent_diffusion
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/lumina2
|
||||
title: Lumina 2.0
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
title: Marigold
|
||||
- local: api/pipelines/panorama
|
||||
title: MultiDiffusion
|
||||
- local: api/pipelines/omnigen
|
||||
title: OmniGen
|
||||
- local: api/pipelines/pag
|
||||
title: PAG
|
||||
- local: api/pipelines/paint_by_example
|
||||
title: Paint by Example
|
||||
- local: api/pipelines/pixart
|
||||
title: PixArt-α
|
||||
- local: api/pipelines/pixart_sigma
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/prx
|
||||
title: PRX
|
||||
- local: api/pipelines/qwenimage
|
||||
title: QwenImage
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
title: Semantic Guidance
|
||||
- local: api/pipelines/shap_e
|
||||
title: Shap-E
|
||||
- local: api/pipelines/stable_cascade
|
||||
title: Stable Cascade
|
||||
- sections:
|
||||
- local: api/pipelines/stable_diffusion/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/stable_diffusion/depth2img
|
||||
title: Depth-to-image
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
title: GLIGEN (Grounded Language-to-Image Generation)
|
||||
- local: api/pipelines/stable_diffusion/image_variation
|
||||
title: Image variation
|
||||
- local: api/pipelines/stable_diffusion/img2img
|
||||
title: Image-to-image
|
||||
- local: api/pipelines/stable_diffusion/inpaint
|
||||
title: Inpainting
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D
|
||||
Upscaler
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
|
||||
title: Safe Stable Diffusion
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_2
|
||||
title: Stable Diffusion 2
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_3
|
||||
title: Stable Diffusion 3
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
|
||||
title: Stable Diffusion XL
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
title: Super-resolution
|
||||
- local: api/pipelines/stable_diffusion/adapter
|
||||
title: T2I-Adapter
|
||||
- local: api/pipelines/stable_diffusion/text2img
|
||||
title: Text-to-image
|
||||
title: Stable Diffusion
|
||||
- local: api/pipelines/stable_unclip
|
||||
title: Stable unCLIP
|
||||
- local: api/pipelines/unclip
|
||||
title: unCLIP
|
||||
- local: api/pipelines/unidiffuser
|
||||
title: UniDiffuser
|
||||
- local: api/pipelines/value_guided_sampling
|
||||
title: Value-guided sampling
|
||||
- local: api/pipelines/visualcloze
|
||||
title: VisualCloze
|
||||
- local: api/pipelines/wuerstchen
|
||||
title: Wuerstchen
|
||||
title: Image
|
||||
- sections:
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consisid
|
||||
title: ConsisID
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTXVideo
|
||||
- local: api/pipelines/mochi
|
||||
title: Mochi
|
||||
- local: api/pipelines/pia
|
||||
title: Personalized Image Animator (PIA)
|
||||
- local: api/pipelines/skyreels_v2
|
||||
title: SkyReels-V2
|
||||
- local: api/pipelines/stable_diffusion/svd
|
||||
title: Image-to-video
|
||||
- local: api/pipelines/stable_diffusion/inpaint
|
||||
title: Inpainting
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
|
||||
title: Safe Stable Diffusion
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_2
|
||||
title: Stable Diffusion 2
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_3
|
||||
title: Stable Diffusion 3
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
|
||||
title: Stable Diffusion XL
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
title: Super-resolution
|
||||
- local: api/pipelines/stable_diffusion/adapter
|
||||
title: T2I-Adapter
|
||||
- local: api/pipelines/stable_diffusion/text2img
|
||||
title: Text-to-image
|
||||
- local: api/pipelines/stable_unclip
|
||||
title: Stable unCLIP
|
||||
- local: api/pipelines/text_to_video
|
||||
title: Text-to-video
|
||||
- local: api/pipelines/text_to_video_zero
|
||||
title: Text2Video-Zero
|
||||
- local: api/pipelines/unclip
|
||||
title: unCLIP
|
||||
- local: api/pipelines/unidiffuser
|
||||
title: UniDiffuser
|
||||
- local: api/pipelines/value_guided_sampling
|
||||
title: Value-guided sampling
|
||||
- local: api/pipelines/visualcloze
|
||||
title: VisualCloze
|
||||
- local: api/pipelines/wan
|
||||
title: Wan
|
||||
- local: api/pipelines/wuerstchen
|
||||
title: Wuerstchen
|
||||
- title: Schedulers
|
||||
sections:
|
||||
title: Stable Video Diffusion
|
||||
- local: api/pipelines/text_to_video
|
||||
title: Text-to-video
|
||||
- local: api/pipelines/text_to_video_zero
|
||||
title: Text2Video-Zero
|
||||
- local: api/pipelines/wan
|
||||
title: Wan
|
||||
title: Video
|
||||
title: Pipelines
|
||||
- sections:
|
||||
- local: api/schedulers/overview
|
||||
title: Overview
|
||||
- local: api/schedulers/cm_stochastic_iterative
|
||||
@@ -718,8 +713,8 @@
|
||||
title: UniPCMultistepScheduler
|
||||
- local: api/schedulers/vq_diffusion
|
||||
title: VQDiffusionScheduler
|
||||
- title: Internal classes
|
||||
sections:
|
||||
title: Schedulers
|
||||
- sections:
|
||||
- local: api/internal_classes_overview
|
||||
title: Overview
|
||||
- local: api/attnprocessor
|
||||
@@ -736,3 +731,5 @@
|
||||
title: VAE Image Processor
|
||||
- local: api/video_processor
|
||||
title: Video Processor
|
||||
title: Internal classes
|
||||
title: API
|
||||
|
||||
@@ -107,6 +107,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin
|
||||
|
||||
## KandinskyLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin
|
||||
|
||||
## LoraBaseMixin
|
||||
|
||||
[[autodoc]] loaders.lora_base.LoraBaseMixin
|
||||
@@ -39,7 +39,7 @@ mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images
|
||||
original_image = load_image(img_url).resize((512, 512))
|
||||
mask_image = load_image(mask_url).resize((512, 512))
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-inpainting")
|
||||
pipe.vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
|
||||
pipe.to("cuda")
|
||||
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# PRX
|
||||
|
||||
|
||||
PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.
|
||||
|
||||
## Available models
|
||||
|
||||
PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
|
||||
|
||||
|
||||
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|
||||
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
|
||||
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.
|
||||
|
||||
## Loading the pipeline
|
||||
|
||||
Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
|
||||
|
||||
```py
|
||||
from diffusers.pipelines.prx import PRXPipeline
|
||||
|
||||
# Load pipeline - VAE and text encoder will be loaded from HuggingFace
|
||||
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A front-facing portrait of a lion the golden savanna at sunset."
|
||||
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
|
||||
image.save("prx_output.png")
|
||||
```
|
||||
|
||||
### Manual Component Loading
|
||||
|
||||
Load components individually to customize the pipeline for instance to use quantized models.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.pipelines.prx import PRXPipeline
|
||||
from diffusers.models import AutoencoderKL, AutoencoderDC
|
||||
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from transformers import T5GemmaModel, GemmaTokenizerFast
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
# Load transformer
|
||||
transformer = PRXTransformer2DModel.from_pretrained(
|
||||
"checkpoints/prx-512-t2i-sft",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Load scheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
"checkpoints/prx-512-t2i-sft", subfolder="scheduler"
|
||||
)
|
||||
|
||||
# Load T5Gemma text encoder
|
||||
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16)
|
||||
text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16)
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
tokenizer.model_max_length = 256
|
||||
|
||||
# Load VAE - choose either Flux VAE or DC-AE
|
||||
# Flux VAE
|
||||
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
|
||||
subfolder="vae",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16)
|
||||
|
||||
pipe = PRXPipeline(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae
|
||||
)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
For memory-constrained environments:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.pipelines.prx import PRXPipeline
|
||||
|
||||
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
|
||||
|
||||
# Or use sequential CPU offload for even lower memory
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
```
|
||||
|
||||
## PRXPipeline
|
||||
|
||||
[[autodoc]] PRXPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## PRXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
|
||||
@@ -21,7 +21,7 @@ The Stable Diffusion model can also infer depth based on an image using [MiDaS](
|
||||
> [!TIP]
|
||||
> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
|
||||
>
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
|
||||
## StableDiffusionDepth2ImgPipeline
|
||||
|
||||
|
||||
@@ -21,14 +21,14 @@ The Stable Diffusion model can also be applied to inpainting which lets you edit
|
||||
## Tips
|
||||
|
||||
It is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such
|
||||
as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). Default
|
||||
as [stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting). Default
|
||||
text-to-image Stable Diffusion checkpoints, such as
|
||||
[stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) are also compatible but they might be less performant.
|
||||
|
||||
> [!TIP]
|
||||
> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
|
||||
>
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
|
||||
## StableDiffusionInpaintPipeline
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ The Stable Diffusion latent upscaler model was created by [Katherine Crowson](ht
|
||||
> [!TIP]
|
||||
> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
|
||||
>
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
|
||||
## StableDiffusionLatentUpscalePipeline
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ Stable Diffusion is trained on 512x512 images from a subset of the LAION-5B data
|
||||
|
||||
For more details about how Stable Diffusion works and how it differs from the base latent diffusion model, take a look at the Stability AI [announcement](https://stability.ai/blog/stable-diffusion-announcement) and our own [blog post](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) for more technical details.
|
||||
|
||||
You can find the original codebase for Stable Diffusion v1.0 at [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) and Stable Diffusion v2.0 at [Stability-AI/stablediffusion](https://github.com/Stability-AI/stablediffusion) as well as their original scripts for various tasks. Additional official checkpoints for the different Stable Diffusion versions and tasks can be found on the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations. Explore these organizations to find the best checkpoint for your use-case!
|
||||
You can find the original codebase for Stable Diffusion v1.0 at [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) and Stable Diffusion v2.0 at [Stability-AI/stablediffusion](https://github.com/Stability-AI/stablediffusion) as well as their original scripts for various tasks. Additional official checkpoints for the different Stable Diffusion versions and tasks can be found on the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations. Explore these organizations to find the best checkpoint for your use-case!
|
||||
|
||||
The table below summarizes the available Stable Diffusion pipelines, their supported tasks, and an interactive demo:
|
||||
|
||||
@@ -64,7 +64,7 @@ The table below summarizes the available Stable Diffusion pipelines, their suppo
|
||||
<a href="./inpaint">StableDiffusionInpaint</a>
|
||||
</td>
|
||||
<td class="px-4 py-2 text-gray-700">inpainting</td>
|
||||
<td class="px-4 py-2"><a href="https://huggingface.co/spaces/runwayml/stable-diffusion-inpainting"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue"/></a>
|
||||
<td class="px-4 py-2"><a href="https://huggingface.co/spaces/stable-diffusion-v1-5/stable-diffusion-inpainting"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue"/></a>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
|
||||
@@ -36,7 +36,7 @@ Here are some examples for how to use Stable Diffusion 2 for each task:
|
||||
> [!TIP]
|
||||
> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
|
||||
>
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
|
||||
## Text-to-image
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ The abstract from the paper is:
|
||||
> [!TIP]
|
||||
> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
|
||||
>
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
|
||||
## StableDiffusionPipeline
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ The Stable Diffusion upscaler diffusion model was created by the researchers and
|
||||
> [!TIP]
|
||||
> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
|
||||
>
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
|
||||
|
||||
## StableDiffusionUpscalePipeline
|
||||
|
||||
|
||||
@@ -16,12 +16,12 @@ pipeline.unet.config["in_channels"]
|
||||
4
|
||||
```
|
||||
|
||||
Inpainting requires 9 channels in the input sample. You can check this value in a pretrained inpainting model like [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting):
|
||||
Inpainting requires 9 channels in the input sample. You can check this value in a pretrained inpainting model like [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting):
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", use_safetensors=True)
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-inpainting", use_safetensors=True)
|
||||
pipeline.unet.config["in_channels"]
|
||||
9
|
||||
```
|
||||
|
||||
@@ -215,7 +215,7 @@ from diffusers import AutoPipelineForInpainting, LCMScheduler
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
pipe = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
|
||||
@@ -112,7 +112,7 @@ blurred_mask
|
||||
|
||||
## Popular models
|
||||
|
||||
[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2 Inpainting](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
|
||||
[Stable Diffusion Inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2 Inpainting](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
|
||||
|
||||
### Stable Diffusion Inpainting
|
||||
|
||||
@@ -124,7 +124,7 @@ from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
|
||||
@@ -244,7 +244,7 @@ make_image_grid([init_image, image], rows=1, cols=2)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="runwayml/stable-diffusion-inpainting">
|
||||
<hfoption id="stable-diffusion-v1-5/stable-diffusion-inpainting">
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -252,7 +252,7 @@ from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
|
||||
@@ -278,7 +278,7 @@ make_image_grid([init_image, image], rows=1, cols=2)
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-specific.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">runwayml/stable-diffusion-inpainting</figcaption>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">stable-diffusion-v1-5/stable-diffusion-inpainting</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -308,7 +308,7 @@ make_image_grid([init_image, image], rows=1, cols=2)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="runwayml/stable-diffusion-inpaint">
|
||||
<hfoption id="stable-diffusion-v1-5/stable-diffusion-inpaint">
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -316,7 +316,7 @@ from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
|
||||
@@ -340,7 +340,7 @@ make_image_grid([init_image, image], rows=1, cols=2)
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/specific-inpaint-basic.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">runwayml/stable-diffusion-inpainting</figcaption>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">stable-diffusion-v1-5/stable-diffusion-inpainting</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -358,7 +358,7 @@ from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
device = "cuda"
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
)
|
||||
@@ -396,7 +396,7 @@ from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
|
||||
@@ -441,7 +441,7 @@ from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
|
||||
@@ -481,7 +481,7 @@ from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
|
||||
@@ -606,7 +606,7 @@ from diffusers import AutoPipelineForInpainting, AutoPipelineForImage2Image
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
|
||||
@@ -683,7 +683,7 @@ from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import make_image_grid
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16,
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16,
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
|
||||
@@ -714,7 +714,7 @@ controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpai
|
||||
|
||||
# pass ControlNet to the pipeline
|
||||
pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
|
||||
|
||||
@@ -173,7 +173,7 @@ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data
|
||||
init_image = download_image(img_url).resize((512, 512))
|
||||
mask_image = download_image(mask_url).resize((512, 512))
|
||||
|
||||
path = "runwayml/stable-diffusion-inpainting"
|
||||
path = "stable-diffusion-v1-5/stable-diffusion-inpainting"
|
||||
|
||||
run_compile = True # Set True / False
|
||||
|
||||
|
||||
@@ -28,12 +28,12 @@ pipeline.unet.config["in_channels"]
|
||||
4
|
||||
```
|
||||
|
||||
인페인팅은 입력 샘플에 9개의 채널이 필요합니다. [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting)와 같은 사전학습된 인페인팅 모델에서 이 값을 확인할 수 있습니다:
|
||||
인페인팅은 입력 샘플에 9개의 채널이 필요합니다. [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)와 같은 사전학습된 인페인팅 모델에서 이 값을 확인할 수 있습니다:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-inpainting")
|
||||
pipeline.unet.config["in_channels"]
|
||||
9
|
||||
```
|
||||
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
[`StableDiffusionInpaintPipeline`]은 마스크와 텍스트 프롬프트를 제공하여 이미지의 특정 부분을 편집할 수 있도록 합니다. 이 기능은 인페인팅 작업을 위해 특별히 훈련된 [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting)과 같은 Stable Diffusion 버전을 사용합니다.
|
||||
[`StableDiffusionInpaintPipeline`]은 마스크와 텍스트 프롬프트를 제공하여 이미지의 특정 부분을 편집할 수 있도록 합니다. 이 기능은 인페인팅 작업을 위해 특별히 훈련된 [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)과 같은 Stable Diffusion 버전을 사용합니다.
|
||||
|
||||
먼저 [`StableDiffusionInpaintPipeline`] 인스턴스를 불러옵니다:
|
||||
|
||||
@@ -27,7 +27,7 @@ from io import BytesIO
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
|
||||
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipeline = pipeline.to("cuda")
|
||||
@@ -61,12 +61,3 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
|
||||
> [!WARNING]
|
||||
> 이전의 실험적인 인페인팅 구현에서는 품질이 낮은 다른 프로세스를 사용했습니다. 이전 버전과의 호환성을 보장하기 위해 새 모델이 포함되지 않은 사전학습된 파이프라인을 불러오면 이전 인페인팅 방법이 계속 적용됩니다.
|
||||
|
||||
아래 Space에서 이미지 인페인팅을 직접 해보세요!
|
||||
|
||||
<iframe
|
||||
src="https://runwayml-stable-diffusion-inpainting.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="500"
|
||||
></iframe>
|
||||
|
||||
@@ -16,12 +16,12 @@ pipeline.unet.config["in_channels"]
|
||||
4
|
||||
```
|
||||
|
||||
而图像修复任务需要输入样本具有9个通道。您可以在 [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) 这样的预训练修复模型中验证此参数:
|
||||
而图像修复任务需要输入样本具有9个通道。您可以在 [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting) 这样的预训练修复模型中验证此参数:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", use_safetensors=True)
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-inpainting", use_safetensors=True)
|
||||
pipeline.unet.config["in_channels"]
|
||||
9
|
||||
```
|
||||
|
||||
@@ -1328,7 +1328,7 @@ model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined"
|
||||
|
||||
# Load Stable Diffusion Inpainting Pipeline with custom pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting",
|
||||
custom_pipeline="text_inpainting",
|
||||
segmentation_model=model,
|
||||
segmentation_processor=processor
|
||||
|
||||
@@ -126,7 +126,7 @@ EXAMPLE_DOC_STRING = """
|
||||
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... )
|
||||
|
||||
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
@@ -347,7 +347,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
@@ -429,8 +429,8 @@ class AdaptiveMaskInpaintPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
@@ -970,7 +970,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
>>> default_mask_image = download_image(mask_url).resize((512, 512))
|
||||
|
||||
>>> pipe = AdaptiveMaskInpaintPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
|
||||
... "stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
@@ -1095,7 +1095,7 @@ class AdaptiveMaskInpaintPipeline(
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
# default case for stable-diffusion-v1-5/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
|
||||
@@ -62,7 +62,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
@@ -145,8 +145,8 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -1276,7 +1276,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
|
||||
@@ -678,7 +678,7 @@ class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
# default case for stable-diffusion-v1-5/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
|
||||
@@ -78,7 +78,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -86,7 +86,7 @@ class InstaFlowPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
@@ -165,8 +165,8 @@ class InstaFlowPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -166,7 +166,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
@@ -247,8 +247,8 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -414,7 +414,7 @@ class StableDiffusionHighResFixPipeline(StableDiffusionPipeline):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
|
||||
@@ -222,7 +222,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
supports [`LCMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
|
||||
@@ -302,7 +302,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
@@ -392,8 +392,8 @@ class LLMGroundedDiffusionPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -552,8 +552,8 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -1765,7 +1765,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
|
||||
# Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
# default case for stable-diffusion-v1-5/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
|
||||
|
||||
@@ -3729,8 +3729,8 @@ class MatryoshkaPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -78,7 +78,7 @@ class MultilingualStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -1607,7 +1607,7 @@ class KolorsControlNetInpaintPipeline(
|
||||
|
||||
# 9. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
# default case for stable-diffusion-v1-5/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
|
||||
@@ -135,7 +135,7 @@ class FabricPipeline(DiffusionPipeline):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
"""
|
||||
|
||||
@@ -163,8 +163,8 @@ class FabricPipeline(DiffusionPipeline):
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -1487,7 +1487,7 @@ class KolorsInpaintPipeline(
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
# default case for stable-diffusion-v1-5/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
|
||||
@@ -106,7 +106,7 @@ class Prompt2PromptPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
@@ -187,8 +187,8 @@ class Prompt2PromptPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -1730,7 +1730,7 @@ class StyleAlignedSDXLPipeline(
|
||||
|
||||
# Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
# default case for stable-diffusion-v1-5/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
|
||||
|
||||
@@ -59,7 +59,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
||||
@@ -392,7 +392,7 @@ class StableDiffusionBoxDiffPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
@@ -473,8 +473,8 @@ class StableDiffusionBoxDiffPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -42,7 +42,7 @@ EXAMPLE_DOC_STRING = """
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
@@ -359,7 +359,7 @@ class StableDiffusionPAGPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
@@ -440,8 +440,8 @@ class StableDiffusionPAGPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -100,7 +100,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
|
||||
@@ -2042,7 +2042,7 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
# default case for stable-diffusion-v1-5/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
|
||||
@@ -188,7 +188,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -330,7 +330,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
|
||||
@@ -1569,7 +1569,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
# default case for stable-diffusion-v1-5/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
|
||||
@@ -46,7 +46,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
||||
@@ -86,7 +86,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
cc_projection ([`CCProjection`]):
|
||||
@@ -164,8 +164,8 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -288,7 +288,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -54,7 +54,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> # load control net and stable diffusion v1-5
|
||||
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... )
|
||||
|
||||
>>> # speed up diffusion process with faster scheduler and memory optimization
|
||||
|
||||
@@ -158,7 +158,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> # load control net and stable diffusion v1-5
|
||||
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... )
|
||||
|
||||
>>> # speed up diffusion process with faster scheduler and memory optimization
|
||||
|
||||
@@ -64,7 +64,7 @@ class StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -114,7 +114,7 @@ class SdeDragPipeline(DiffusionPipeline):
|
||||
>>> from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
|
||||
>>> # Load the pipeline
|
||||
>>> model_path = "runwayml/stable-diffusion-v1-5"
|
||||
>>> model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
>>> scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
>>> pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
|
||||
>>> pipe.to('cuda')
|
||||
|
||||
@@ -46,7 +46,7 @@ class StableDiffusionComparisonPipeline(DiffusionPipeline, StableDiffusionMixin)
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionMegaSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -36,7 +36,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||
|
||||
>>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
torch_dtype=torch.float16
|
||||
|
||||
@@ -81,7 +81,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
|
||||
|
||||
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
@@ -80,7 +80,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
|
||||
|
||||
>>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
|
||||
"stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
@@ -37,7 +37,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
torch_dtype=torch.float16
|
||||
|
||||
@@ -43,7 +43,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
|
||||
>>> pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
|
||||
|
||||
>>> # For Float32
|
||||
>>> pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) #value of image height/width should be consistent with the pipeline inference
|
||||
@@ -85,7 +85,7 @@ class StableDiffusionIPEXPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
@@ -161,8 +161,8 @@ class StableDiffusionIPEXPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -47,7 +47,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionMegaSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -46,7 +46,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
|
||||
|
||||
>>> pipe = StableDiffusionReferencePipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
safety_checker=None,
|
||||
torch_dtype=torch.float16
|
||||
).to('cuda:0')
|
||||
@@ -112,7 +112,7 @@ class StableDiffusionReferencePipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
@@ -194,8 +194,8 @@ class StableDiffusionReferencePipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -167,7 +167,7 @@ class StableDiffusionRepaintPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
@@ -249,8 +249,8 @@ class StableDiffusionRepaintPipeline(
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -678,7 +678,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
@@ -766,8 +766,8 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -682,7 +682,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
@@ -770,8 +770,8 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -594,7 +594,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
@@ -682,8 +682,8 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
" \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
|
||||
@@ -52,7 +52,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -1223,7 +1223,7 @@ class AnyTextPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
|
||||
@@ -5,7 +5,7 @@ This script was added by @thedarkzeno .
|
||||
Please note that this script is not actively maintained, you can open an issue and tag @thedarkzeno or @patil-suraj though.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
||||
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -29,7 +29,7 @@ Prior-preservation is used to avoid overfitting and language-drift. Refer to the
|
||||
According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
||||
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
@@ -60,7 +60,7 @@ With the help of gradient checkpointing and the 8-bit optimizer from bitsandbyte
|
||||
To install `bitandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
||||
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
@@ -92,7 +92,7 @@ Pass the `--train_text_encoder` argument to the script to enable training `text_
|
||||
___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
||||
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -55,7 +55,7 @@ The Accelerate launch command is used to train a model using multiple GPUs and m
|
||||
```
|
||||
accelerate launch --mixed_precision "fp16" \
|
||||
tutorial_train_ip-adapter.py \
|
||||
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
|
||||
--pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5/" \
|
||||
--image_encoder_path="{image_encoder_path}" \
|
||||
--data_json_file="{data.json}" \
|
||||
--data_root_path="{image_path}" \
|
||||
@@ -73,7 +73,7 @@ tutorial_train_ip-adapter.py \
|
||||
```
|
||||
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
|
||||
tutorial_train_ip-adapter.py \
|
||||
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
|
||||
--pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5/" \
|
||||
--image_encoder_path="{image_encoder_path}" \
|
||||
--data_json_file="{data.json}" \
|
||||
--data_root_path="{image_path}" \
|
||||
|
||||
@@ -27,7 +27,7 @@ You can build multiple datasets for every subject and upload them to the 🤗 hu
|
||||
Before launching the training script, make sure to select the inpainting the target model, the output directory and the 🤗 datasets.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
||||
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
export DATASET_1="gzguevara/mr_potato_head_masked"
|
||||
|
||||
@@ -177,7 +177,7 @@ class PromptDiffusionPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
|
||||
@@ -238,7 +238,7 @@ def parse_args() -> argparse.Namespace:
|
||||
|
||||
# EXAMPLE USAGE:
|
||||
#
|
||||
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
|
||||
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "stable-diffusion-v1-5/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
|
||||
#
|
||||
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
|
||||
#
|
||||
|
||||
@@ -24,7 +24,8 @@ args = args.parse_args()
|
||||
|
||||
|
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
# from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895 """
|
||||
# from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895
|
||||
# """
|
||||
res = arr[timesteps].float()
|
||||
dims_to_append = len(broadcast_shape) - len(res.shape)
|
||||
return res[(...,) + (None,) * dims_to_append]
|
||||
@@ -507,7 +508,9 @@ def rename_state_dict(sd, embedding):
|
||||
|
||||
|
||||
# encode with stable diffusion vae
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.vae.cuda()
|
||||
|
||||
# construct original decoder with jitted model
|
||||
@@ -1090,7 +1093,7 @@ def new_constructor(self, **kwargs):
|
||||
Encoder.__init__ = new_constructor
|
||||
|
||||
|
||||
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
|
||||
vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
|
||||
consistency_vae = ConsistencyDecoderVAE(
|
||||
encoder_args=vae.encoder.constructor_arguments,
|
||||
decoder_args=unet.config,
|
||||
@@ -1117,7 +1120,7 @@ print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
|
||||
print("running with diffusers pipeline")
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
|
||||
@@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to convert PRX checkpoint from original codebase to diffusers format.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
from diffusers.pipelines.prx import PRXPipeline
|
||||
|
||||
|
||||
DEFAULT_RESOLUTION = 512
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PRXBase:
|
||||
context_in_dim: int = 2304
|
||||
hidden_size: int = 1792
|
||||
mlp_ratio: float = 3.5
|
||||
num_heads: int = 28
|
||||
depth: int = 16
|
||||
axes_dim: Tuple[int, int] = (32, 32)
|
||||
theta: int = 10_000
|
||||
time_factor: float = 1000.0
|
||||
time_max_period: int = 10_000
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PRXFlux(PRXBase):
|
||||
in_channels: int = 16
|
||||
patch_size: int = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PRXDCAE(PRXBase):
|
||||
in_channels: int = 32
|
||||
patch_size: int = 1
|
||||
|
||||
|
||||
def build_config(vae_type: str) -> Tuple[dict, int]:
|
||||
if vae_type == "flux":
|
||||
cfg = PRXFlux()
|
||||
elif vae_type == "dc-ae":
|
||||
cfg = PRXDCAE()
|
||||
else:
|
||||
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
|
||||
|
||||
config_dict = asdict(cfg)
|
||||
config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
|
||||
return config_dict
|
||||
|
||||
|
||||
def create_parameter_mapping(depth: int) -> dict:
|
||||
"""Create mapping from old parameter names to new diffusers names."""
|
||||
|
||||
# Key mappings for structural changes
|
||||
mapping = {}
|
||||
|
||||
# Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
|
||||
for i in range(depth):
|
||||
# QKV projections moved to attention module
|
||||
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
|
||||
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
|
||||
|
||||
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
|
||||
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
|
||||
|
||||
# K norm for text tokens moved to attention module
|
||||
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
|
||||
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
|
||||
|
||||
# Attention output projection
|
||||
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
|
||||
"""Convert old checkpoint parameters to new diffusers format."""
|
||||
|
||||
print("Converting checkpoint parameters...")
|
||||
|
||||
mapping = create_parameter_mapping(depth)
|
||||
converted_state_dict = {}
|
||||
|
||||
for key, value in old_state_dict.items():
|
||||
new_key = key
|
||||
|
||||
# Apply specific mappings if needed
|
||||
if key in mapping:
|
||||
new_key = mapping[key]
|
||||
print(f" Mapped: {key} -> {new_key}")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
print(f"✓ Converted {len(converted_state_dict)} parameters")
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
|
||||
"""Create and load PRXTransformer2DModel from old checkpoint."""
|
||||
|
||||
print(f"Loading checkpoint from: {checkpoint_path}")
|
||||
|
||||
# Load old checkpoint
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
||||
|
||||
old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
# Handle different checkpoint formats
|
||||
if isinstance(old_checkpoint, dict):
|
||||
if "model" in old_checkpoint:
|
||||
state_dict = old_checkpoint["model"]
|
||||
elif "state_dict" in old_checkpoint:
|
||||
state_dict = old_checkpoint["state_dict"]
|
||||
else:
|
||||
state_dict = old_checkpoint
|
||||
else:
|
||||
state_dict = old_checkpoint
|
||||
|
||||
print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
|
||||
|
||||
# Convert parameter names if needed
|
||||
model_depth = int(config.get("depth", 16))
|
||||
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
|
||||
|
||||
# Create transformer with config
|
||||
print("Creating PRXTransformer2DModel...")
|
||||
transformer = PRXTransformer2DModel(**config)
|
||||
|
||||
# Load state dict
|
||||
print("Loading converted parameters...")
|
||||
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"⚠ Missing keys: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
print(f"⚠ Unexpected keys: {unexpected_keys}")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("✓ All parameters loaded successfully!")
|
||||
|
||||
return transformer
|
||||
|
||||
|
||||
def create_scheduler_config(output_path: str, shift: float):
|
||||
"""Create FlowMatchEulerDiscreteScheduler config."""
|
||||
|
||||
scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
|
||||
|
||||
scheduler_path = os.path.join(output_path, "scheduler")
|
||||
os.makedirs(scheduler_path, exist_ok=True)
|
||||
|
||||
with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
|
||||
json.dump(scheduler_config, f, indent=2)
|
||||
|
||||
print("✓ Created scheduler config")
|
||||
|
||||
|
||||
def download_and_save_vae(vae_type: str, output_path: str):
|
||||
"""Download and save VAE to local directory."""
|
||||
from diffusers import AutoencoderDC, AutoencoderKL
|
||||
|
||||
vae_path = os.path.join(output_path, "vae")
|
||||
os.makedirs(vae_path, exist_ok=True)
|
||||
|
||||
if vae_type == "flux":
|
||||
print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
|
||||
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
|
||||
else: # dc-ae
|
||||
print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
|
||||
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
|
||||
|
||||
vae.save_pretrained(vae_path)
|
||||
print(f"✓ Saved VAE to {vae_path}")
|
||||
|
||||
|
||||
def download_and_save_text_encoder(output_path: str):
|
||||
"""Download and save T5Gemma text encoder and tokenizer."""
|
||||
from transformers import GemmaTokenizerFast
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
|
||||
|
||||
text_encoder_path = os.path.join(output_path, "text_encoder")
|
||||
tokenizer_path = os.path.join(output_path, "tokenizer")
|
||||
os.makedirs(text_encoder_path, exist_ok=True)
|
||||
os.makedirs(tokenizer_path, exist_ok=True)
|
||||
|
||||
print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
|
||||
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
|
||||
# Extract and save only the encoder
|
||||
t5gemma_encoder = t5gemma_model.encoder
|
||||
t5gemma_encoder.save_pretrained(text_encoder_path)
|
||||
print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
|
||||
|
||||
print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
tokenizer.model_max_length = 256
|
||||
tokenizer.save_pretrained(tokenizer_path)
|
||||
print(f"✓ Saved tokenizer to {tokenizer_path}")
|
||||
|
||||
|
||||
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
|
||||
"""Create model_index.json for the pipeline."""
|
||||
|
||||
if vae_type == "flux":
|
||||
vae_class = "AutoencoderKL"
|
||||
else: # dc-ae
|
||||
vae_class = "AutoencoderDC"
|
||||
|
||||
model_index = {
|
||||
"_class_name": "PRXPipeline",
|
||||
"_diffusers_version": "0.31.0.dev0",
|
||||
"_name_or_path": os.path.basename(output_path),
|
||||
"default_sample_size": default_image_size,
|
||||
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
|
||||
"text_encoder": ["prx", "T5GemmaEncoder"],
|
||||
"tokenizer": ["transformers", "GemmaTokenizerFast"],
|
||||
"transformer": ["diffusers", "PRXTransformer2DModel"],
|
||||
"vae": ["diffusers", vae_class],
|
||||
}
|
||||
|
||||
model_index_path = os.path.join(output_path, "model_index.json")
|
||||
with open(model_index_path, "w") as f:
|
||||
json.dump(model_index, f, indent=2)
|
||||
|
||||
|
||||
def main(args):
|
||||
# Validate inputs
|
||||
if not os.path.exists(args.checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
|
||||
|
||||
config = build_config(args.vae_type)
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
print(f"✓ Output directory: {args.output_path}")
|
||||
|
||||
# Create transformer from checkpoint
|
||||
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
|
||||
|
||||
# Save transformer
|
||||
transformer_path = os.path.join(args.output_path, "transformer")
|
||||
os.makedirs(transformer_path, exist_ok=True)
|
||||
|
||||
# Save config
|
||||
with open(os.path.join(transformer_path, "config.json"), "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# Save model weights as safetensors
|
||||
state_dict = transformer.state_dict()
|
||||
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
|
||||
print(f"✓ Saved transformer to {transformer_path}")
|
||||
|
||||
# Create scheduler config
|
||||
create_scheduler_config(args.output_path, args.shift)
|
||||
|
||||
download_and_save_vae(args.vae_type, args.output_path)
|
||||
download_and_save_text_encoder(args.output_path)
|
||||
|
||||
# Create model_index.json
|
||||
create_model_index(args.vae_type, args.resolution, args.output_path)
|
||||
|
||||
# Verify the pipeline can be loaded
|
||||
try:
|
||||
pipeline = PRXPipeline.from_pretrained(args.output_path)
|
||||
print("Pipeline loaded successfully!")
|
||||
print(f"Transformer: {type(pipeline.transformer).__name__}")
|
||||
print(f"VAE: {type(pipeline.vae).__name__}")
|
||||
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
|
||||
print(f"Scheduler: {type(pipeline.scheduler).__name__}")
|
||||
|
||||
# Display model info
|
||||
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
|
||||
print(f"✓ Transformer parameters: {num_params:,}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Pipeline verification failed: {e}")
|
||||
return False
|
||||
|
||||
print("Conversion completed successfully!")
|
||||
print(f"Converted pipeline saved to: {args.output_path}")
|
||||
print(f"VAE type: {args.vae_type}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vae_type",
|
||||
type=str,
|
||||
choices=["flux", "dc-ae"],
|
||||
required=True,
|
||||
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
choices=[256, 512, 1024],
|
||||
default=DEFAULT_RESOLUTION,
|
||||
help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--shift",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="Shift for the scheduler",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
success = main(args)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Conversion failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -220,6 +220,7 @@ else:
|
||||
"HunyuanVideoTransformer3DModel",
|
||||
"I2VGenXLUNet",
|
||||
"Kandinsky3UNet",
|
||||
"Kandinsky5Transformer3DModel",
|
||||
"LatteTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
"Lumina2Transformer2DModel",
|
||||
@@ -233,6 +234,7 @@ else:
|
||||
"ParallelConfig",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"PRXTransformer2DModel",
|
||||
"QwenImageControlNetModel",
|
||||
"QwenImageMultiControlNetModel",
|
||||
"QwenImageTransformer2DModel",
|
||||
@@ -474,6 +476,7 @@ else:
|
||||
"ImageTextPipelineOutput",
|
||||
"Kandinsky3Img2ImgPipeline",
|
||||
"Kandinsky3Pipeline",
|
||||
"Kandinsky5T2VPipeline",
|
||||
"KandinskyCombinedPipeline",
|
||||
"KandinskyImg2ImgCombinedPipeline",
|
||||
"KandinskyImg2ImgPipeline",
|
||||
@@ -517,6 +520,7 @@ else:
|
||||
"PixArtAlphaPipeline",
|
||||
"PixArtSigmaPAGPipeline",
|
||||
"PixArtSigmaPipeline",
|
||||
"PRXPipeline",
|
||||
"QwenImageControlNetInpaintPipeline",
|
||||
"QwenImageControlNetPipeline",
|
||||
"QwenImageEditInpaintPipeline",
|
||||
@@ -912,6 +916,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanVideoTransformer3DModel,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
@@ -925,6 +930,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ParallelConfig,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
PRXTransformer2DModel,
|
||||
QwenImageControlNetModel,
|
||||
QwenImageMultiControlNetModel,
|
||||
QwenImageTransformer2DModel,
|
||||
@@ -1136,6 +1142,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ImageTextPipelineOutput,
|
||||
Kandinsky3Img2ImgPipeline,
|
||||
Kandinsky3Pipeline,
|
||||
Kandinsky5T2VPipeline,
|
||||
KandinskyCombinedPipeline,
|
||||
KandinskyImg2ImgCombinedPipeline,
|
||||
KandinskyImg2ImgPipeline,
|
||||
@@ -1179,6 +1186,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PixArtAlphaPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
PixArtSigmaPipeline,
|
||||
PRXPipeline,
|
||||
QwenImageControlNetInpaintPipeline,
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
|
||||
@@ -77,6 +77,7 @@ if is_torch_available():
|
||||
"SanaLoraLoaderMixin",
|
||||
"Lumina2LoraLoaderMixin",
|
||||
"WanLoraLoaderMixin",
|
||||
"KandinskyLoraLoaderMixin",
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
"SkyReelsV2LoraLoaderMixin",
|
||||
"QwenImageLoraLoaderMixin",
|
||||
@@ -115,6 +116,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
KandinskyLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Lumina2LoraLoaderMixin,
|
||||
|
||||
@@ -3639,6 +3639,291 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`Kandinsky5Transformer3DModel`],
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
- A string, the *model id* of a pretrained model hosted on the Hub.
|
||||
- A path to a *directory* containing the model weights.
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository.
|
||||
weight_name (`str`, *optional*, defaults to None):
|
||||
Name of the serialized state dict file.
|
||||
use_safetensors (`bool`, *optional*):
|
||||
Whether to use safetensors for loading.
|
||||
return_lora_metadata (`bool`, *optional*, defaults to False):
|
||||
When enabled, additionally return the LoRA adapter metadata.
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model.
|
||||
hotswap (`bool`, *optional*):
|
||||
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
kwargs["return_lora_metadata"] = True
|
||||
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
# Load LoRA into transformer
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
"""
|
||||
Load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters.
|
||||
transformer (`Kandinsky5Transformer3DModel`):
|
||||
The transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
|
||||
metadata (`dict`):
|
||||
Optional LoRA adapter metadata.
|
||||
"""
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata=None,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the transformer and text encoders.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way.
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer.
|
||||
"""
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers`")
|
||||
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from diffusers import Kandinsky5T2VPipeline
|
||||
|
||||
pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
|
||||
pipeline.load_lora_weights("path/to/lora.safetensors")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of [`pipe.fuse_lora()`].
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
|
||||
|
||||
@@ -293,7 +293,7 @@ class PeftAdapterMixin:
|
||||
# For hotswapping, we need the adapter name to be present in the state dict keys
|
||||
new_sd = {}
|
||||
for k, v in sd.items():
|
||||
if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
|
||||
if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
|
||||
k = k[: -len(".weight")] + f".{adapter_name}.weight"
|
||||
elif k.endswith("lora_B.bias"): # lora_bias=True option
|
||||
k = k[: -len(".bias")] + f".{adapter_name}.bias"
|
||||
|
||||
@@ -91,10 +91,12 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
@@ -182,6 +184,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
@@ -190,6 +193,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
OmniGenTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
PRXTransformer2DModel,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaTransformer2DModel,
|
||||
SD3Transformer2DModel,
|
||||
|
||||
@@ -128,13 +128,13 @@ class AutoModel(ConfigMixin):
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
|
||||
unet = AutoModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
||||
unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
|
||||
```
|
||||
|
||||
If you get the error message below, you need to finetune the weights for your downstream task:
|
||||
|
||||
```bash
|
||||
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
||||
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
||||
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
```
|
||||
|
||||
@@ -20,10 +20,10 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
|
||||
|
||||
|
||||
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
||||
class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
|
||||
KL loss for encoding images into latents and decoding latent representations into images.
|
||||
@@ -107,9 +107,6 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
self.register_to_config(block_out_channels=up_block_out_channels)
|
||||
self.register_to_config(force_upcast=False)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from ..attention_processor import SanaMultiscaleLinearAttention
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm, get_normalization
|
||||
from ..transformers.sana_transformer import GLUMBConv
|
||||
from .vae import DecoderOutput, EncoderOutput
|
||||
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
@@ -378,7 +378,7 @@ class Decoder(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
|
||||
[SANA](https://huggingface.co/papers/2410.10629).
|
||||
@@ -536,27 +536,6 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
|
||||
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
|
||||
@@ -32,10 +32,10 @@ from ..attention_processor import (
|
||||
)
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
|
||||
class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
@@ -138,35 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.use_tiling = use_tiling
|
||||
|
||||
def disable_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.enable_tiling(False)
|
||||
|
||||
def enable_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
|
||||
@@ -28,6 +28,7 @@ from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..resnet import ResnetBlock2D
|
||||
from ..upsampling import Upsample2D
|
||||
from .vae import AutoencoderMixin
|
||||
|
||||
|
||||
class AllegroTemporalConvLayer(nn.Module):
|
||||
@@ -673,7 +674,7 @@ class AllegroDecoder3D(nn.Module):
|
||||
return sample
|
||||
|
||||
|
||||
class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
|
||||
class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
|
||||
[Allegro](https://github.com/rhymes-ai/Allegro).
|
||||
@@ -795,35 +796,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
|
||||
sample_size - self.tile_overlap_w,
|
||||
)
|
||||
|
||||
def enable_tiling(self) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(aryan)
|
||||
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
|
||||
@@ -29,7 +29,7 @@ from ..downsampling import CogVideoXDownsample3D
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..upsampling import CogVideoXUpsample3D
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -955,7 +955,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
return hidden_states, new_conv_cache
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
|
||||
[CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
@@ -1124,27 +1124,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
|
||||
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = x.shape
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...utils import get_logger
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, IdentityDistribution
|
||||
from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -875,7 +875,7 @@ class CosmosDecoder3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
|
||||
class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
|
||||
|
||||
@@ -1031,27 +1031,6 @@ class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.encoder(x)
|
||||
enc = self.quant_conv(x)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user