add PAG support for Stable Diffusion 3 (#8861)
add pag sd3 --------- Co-authored-by: HyoungwonCho <jhw9811@korea.ac.kr> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: crepejung00 <jaewoojung00@naver.com> Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: Aryan <contact.aryanvs@gmail.com> Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
325a5de3a9
commit
926daa30f9
@@ -74,6 +74,12 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- __call__
|
||||
|
||||
|
||||
## StableDiffusion3PAGPipeline
|
||||
[[autodoc]] StableDiffusion3PAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## PixArtSigmaPAGPipeline
|
||||
[[autodoc]] PixArtSigmaPAGPipeline
|
||||
- all
|
||||
|
||||
@@ -308,6 +308,7 @@ else:
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
@@ -741,6 +742,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
|
||||
@@ -1106,6 +1106,326 @@ class JointAttnProcessor2_0:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class PAGJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
# store the length of image patch sequences to create a mask that prevents interaction between patches
|
||||
# similar to making the self-attention map an identity matrix
|
||||
identity_block_size = hidden_states.shape[1]
|
||||
|
||||
# chunk
|
||||
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
|
||||
encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
|
||||
|
||||
################## original path ##################
|
||||
batch_size = encoder_hidden_states_org.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_org = attn.to_q(hidden_states_org)
|
||||
key_org = attn.to_k(hidden_states_org)
|
||||
value_org = attn.to_v(hidden_states_org)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
|
||||
|
||||
# attention
|
||||
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
|
||||
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
|
||||
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_org.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states_org = F.scaled_dot_product_attention(
|
||||
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_org = hidden_states_org.to(query_org.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
hidden_states_org, encoder_hidden_states_org = (
|
||||
hidden_states_org[:, : residual.shape[1]],
|
||||
hidden_states_org[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_org = attn.to_out[0](hidden_states_org)
|
||||
# dropout
|
||||
hidden_states_org = attn.to_out[1](hidden_states_org)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################## perturbed path ##################
|
||||
|
||||
batch_size = encoder_hidden_states_ptb.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_ptb = attn.to_q(hidden_states_ptb)
|
||||
key_ptb = attn.to_k(hidden_states_ptb)
|
||||
value_ptb = attn.to_v(hidden_states_ptb)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
|
||||
|
||||
# attention
|
||||
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
|
||||
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
|
||||
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_ptb.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# create a full mask with all entries set to 0
|
||||
seq_len = query_ptb.size(2)
|
||||
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
|
||||
|
||||
# set the attention value between image patches to -inf
|
||||
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
|
||||
|
||||
# set the diagonal of the attention value between image patches to 0
|
||||
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
|
||||
|
||||
# expand the mask to match the attention weights shape
|
||||
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
|
||||
|
||||
hidden_states_ptb = F.scaled_dot_product_attention(
|
||||
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
|
||||
|
||||
# split the attention outputs.
|
||||
hidden_states_ptb, encoder_hidden_states_ptb = (
|
||||
hidden_states_ptb[:, : residual.shape[1]],
|
||||
hidden_states_ptb[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
||||
# dropout
|
||||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################ concat ###############
|
||||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class PAGCFGJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
identity_block_size = hidden_states.shape[
|
||||
1
|
||||
] # patch embeddings width * height (correspond to self-attention map width or height)
|
||||
|
||||
# chunk
|
||||
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
|
||||
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
|
||||
|
||||
(
|
||||
encoder_hidden_states_uncond,
|
||||
encoder_hidden_states_org,
|
||||
encoder_hidden_states_ptb,
|
||||
) = encoder_hidden_states.chunk(3)
|
||||
encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
|
||||
|
||||
################## original path ##################
|
||||
batch_size = encoder_hidden_states_org.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_org = attn.to_q(hidden_states_org)
|
||||
key_org = attn.to_k(hidden_states_org)
|
||||
value_org = attn.to_v(hidden_states_org)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
|
||||
|
||||
# attention
|
||||
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
|
||||
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
|
||||
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_org.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states_org = F.scaled_dot_product_attention(
|
||||
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_org = hidden_states_org.to(query_org.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
hidden_states_org, encoder_hidden_states_org = (
|
||||
hidden_states_org[:, : residual.shape[1]],
|
||||
hidden_states_org[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_org = attn.to_out[0](hidden_states_org)
|
||||
# dropout
|
||||
hidden_states_org = attn.to_out[1](hidden_states_org)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################## perturbed path ##################
|
||||
|
||||
batch_size = encoder_hidden_states_ptb.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_ptb = attn.to_q(hidden_states_ptb)
|
||||
key_ptb = attn.to_k(hidden_states_ptb)
|
||||
value_ptb = attn.to_v(hidden_states_ptb)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
|
||||
|
||||
# attention
|
||||
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
|
||||
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
|
||||
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_ptb.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# create a full mask with all entries set to 0
|
||||
seq_len = query_ptb.size(2)
|
||||
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
|
||||
|
||||
# set the attention value between image patches to -inf
|
||||
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
|
||||
|
||||
# set the diagonal of the attention value between image patches to 0
|
||||
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
|
||||
|
||||
# expand the mask to match the attention weights shape
|
||||
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
|
||||
|
||||
hidden_states_ptb = F.scaled_dot_product_attention(
|
||||
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
|
||||
|
||||
# split the attention outputs.
|
||||
hidden_states_ptb, encoder_hidden_states_ptb = (
|
||||
hidden_states_ptb[:, : residual.shape[1]],
|
||||
hidden_states_ptb[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
||||
# dropout
|
||||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################ concat ###############
|
||||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
|
||||
@@ -147,6 +147,7 @@ else:
|
||||
[
|
||||
"AnimateDiffPAGPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusionPAGPipeline",
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionXLPAGPipeline",
|
||||
@@ -540,6 +541,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AnimateDiffPAGPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
|
||||
@@ -52,6 +52,7 @@ from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, La
|
||||
from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
@@ -84,6 +85,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion", StableDiffusionPipeline),
|
||||
("stable-diffusion-xl", StableDiffusionXLPipeline),
|
||||
("stable-diffusion-3", StableDiffusion3Pipeline),
|
||||
("stable-diffusion-3-pag", StableDiffusion3PAGPipeline),
|
||||
("if", IFPipeline),
|
||||
("hunyuan", HunyuanDiTPipeline),
|
||||
("hunyuan-pag", HunyuanDiTPAGPipeline),
|
||||
|
||||
@@ -27,6 +27,7 @@ else:
|
||||
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
|
||||
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
|
||||
@@ -45,6 +46,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
|
||||
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
|
||||
from .pipeline_pag_sd import StableDiffusionPAGPipeline
|
||||
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
|
||||
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
|
||||
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
|
||||
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
|
||||
|
||||
@@ -0,0 +1,985 @@
|
||||
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
|
||||
from ...models.attention_processor import PAGCFGJointAttnProcessor2_0, PAGJointAttnProcessor2_0
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import SD3Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
||||
from .pag_utils import PAGMixin
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import AutoPipelineForText2Image
|
||||
|
||||
>>> pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
... torch_dtype=torch.float16,
|
||||
... enable_pag=True,
|
||||
... pag_applied_layers=["blocks.13"],
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = pipe(prompt, guidance_scale=5.0, pag_scale=0.7).images[0]
|
||||
>>> image.save("sd3_pag.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, PAGMixin):
|
||||
r"""
|
||||
[PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation
|
||||
using Stable Diffusion 3.
|
||||
|
||||
Args:
|
||||
transformer ([`SD3Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
||||
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
||||
as its dimension.
|
||||
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the
|
||||
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
||||
variant.
|
||||
text_encoder_3 ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Stable Diffusion 3 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Second Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: SD3Transformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
text_encoder_3=text_encoder_3,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
tokenizer_3=tokenizer_3,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
|
||||
self.set_pag_applied_layers(
|
||||
pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0())
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 256,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if self.text_encoder_3 is None:
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer_3(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = self.text_encoder_3.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
clip_model_index: int = 0,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
||||
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
||||
|
||||
tokenizer = clip_tokenizers[clip_model_index]
|
||||
text_encoder = clip_text_encoders[clip_model_index]
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_3: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
max_sequence_length: int = 256,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
||||
|
||||
prompt_3 = prompt_3 or prompt
|
||||
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
||||
|
||||
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=0,
|
||||
)
|
||||
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=1,
|
||||
)
|
||||
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
||||
|
||||
t5_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=prompt_3,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
clip_prompt_embeds = torch.nn.functional.pad(
|
||||
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
||||
|
||||
# normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_2 = (
|
||||
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||
)
|
||||
negative_prompt_3 = (
|
||||
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
||||
)
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=0,
|
||||
)
|
||||
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=1,
|
||||
)
|
||||
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
||||
|
||||
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt_3,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
||||
negative_clip_prompt_embeds,
|
||||
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
||||
)
|
||||
|
||||
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
||||
negative_pooled_prompt_embeds = torch.cat(
|
||||
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
||||
)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
negative_prompt_3=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_3 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
||||
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 256,
|
||||
pag_scale: float = 3.0,
|
||||
pag_adaptive_scale: float = 0.0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
will be used instead
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
||||
will be used instead
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
||||
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
||||
pag_scale (`float`, *optional*, defaults to 3.0):
|
||||
The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
|
||||
guidance will not be used.
|
||||
pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
|
||||
The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
|
||||
used.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._pag_scale = pag_scale
|
||||
self._pag_adaptive_scale = pag_adaptive_scale #
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_3=prompt_3,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
clip_skip=self.clip_skip,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
if self.do_perturbed_attention_guidance:
|
||||
prompt_embeds = self._prepare_perturbed_attention_guidance(
|
||||
prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
|
||||
)
|
||||
pooled_prompt_embeds = self._prepare_perturbed_attention_guidance(
|
||||
pooled_prompt_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance
|
||||
)
|
||||
elif self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if self.do_perturbed_attention_guidance:
|
||||
original_attn_proc = self.transformer.attn_processors
|
||||
self._set_pag_attn_processor(
|
||||
pag_applied_layers=self.pag_applied_layers,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both
|
||||
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_perturbed_attention_guidance:
|
||||
noise_pred = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
|
||||
)
|
||||
|
||||
elif self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if self.do_perturbed_attention_guidance:
|
||||
self.transformer.set_attn_processor(original_attn_proc)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusion3PipelineOutput(images=image)
|
||||
@@ -1127,6 +1127,21 @@ class StableDiffusion3InpaintPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusion3PAGPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusion3Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SD3Transformer2DModel,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
)
|
||||
|
||||
|
||||
class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = StableDiffusion3PAGPipeline
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"guidance_scale",
|
||||
"negative_prompt",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = SD3Transformer2DModel(
|
||||
sample_size=32,
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
num_layers=2,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=4,
|
||||
caption_projection_dim=32,
|
||||
joint_attention_dim=32,
|
||||
pooled_projection_dim=64,
|
||||
out_channels=4,
|
||||
)
|
||||
clip_text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=4,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
shift_factor=0.0609,
|
||||
scaling_factor=1.5035,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"text_encoder_3": text_encoder_3,
|
||||
"tokenizer": tokenizer,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"tokenizer_3": tokenizer_3,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
"pag_scale": 0.0,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_3_different_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = "a different prompt"
|
||||
inputs["prompt_3"] = "another different prompt"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
|
||||
def test_stable_diffusion_3_different_negative_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["negative_prompt_2"] = "deformed"
|
||||
inputs["negative_prompt_3"] = "blurry"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
|
||||
def test_stable_diffusion_3_prompt_embeds(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
output_with_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = inputs.pop("prompt")
|
||||
|
||||
do_classifier_free_guidance = inputs["guidance_scale"] > 1
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = pipe.encode_prompt(
|
||||
prompt,
|
||||
prompt_2=None,
|
||||
prompt_3=None,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
device=torch_device,
|
||||
)
|
||||
output_with_embeds = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
**inputs,
|
||||
).images[0]
|
||||
|
||||
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(
|
||||
pipe.transformer
|
||||
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_fused = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_pag_disable_enable(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
# base pipeline (expect same output when pag is disabled)
|
||||
pipe_sd = StableDiffusion3Pipeline(**components)
|
||||
pipe_sd = pipe_sd.to(device)
|
||||
pipe_sd.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["pag_scale"]
|
||||
assert (
|
||||
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
# pag disabled with pag_scale=0.0
|
||||
pipe_pag = self.pipeline_class(**components)
|
||||
pipe_pag = pipe_pag.to(device)
|
||||
pipe_pag.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["pag_scale"] = 0.0
|
||||
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
|
||||
|
||||
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
|
||||
|
||||
def test_pag_applied_layers(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
# base pipeline
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn" in k]
|
||||
original_attn_procs = pipe.transformer.attn_processors
|
||||
pag_layers = ["blocks.0", "blocks.1"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
|
||||
|
||||
# blocks.0
|
||||
block_0_self_attn = ["transformer_blocks.0.attn.processor"]
|
||||
pipe.transformer.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["blocks.0"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
|
||||
|
||||
pipe.transformer.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["blocks.0.attn"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
|
||||
|
||||
pipe.transformer.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["blocks.(0|1)"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert (len(pipe.pag_attn_processors)) == 2
|
||||
|
||||
pipe.transformer.set_attn_processor(original_attn_procs.copy())
|
||||
pag_layers = ["blocks.0", r"blocks\.1"]
|
||||
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||
assert len(pipe.pag_attn_processors) == 2
|
||||
Reference in New Issue
Block a user