[Kolors] Add PAG (#8934)
* txt2img pag added * autopipe added, fixed case * style * apply suggestions * added fast tests, added todo tests * revert dummy objects for kolors * fix pag dummies * fix test imports * update pag tests * add kolor pag to docs --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -43,6 +43,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
|||||||
- all
|
- all
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
|
## KolorsPAGPipeline
|
||||||
|
[[autodoc]] KolorsPAGPipeline
|
||||||
|
- all
|
||||||
|
- __call__
|
||||||
|
|
||||||
## StableDiffusionPAGPipeline
|
## StableDiffusionPAGPipeline
|
||||||
[[autodoc]] StableDiffusionPAGPipeline
|
[[autodoc]] StableDiffusionPAGPipeline
|
||||||
- all
|
- all
|
||||||
|
|||||||
@@ -280,8 +280,6 @@ else:
|
|||||||
"KandinskyV22Pipeline",
|
"KandinskyV22Pipeline",
|
||||||
"KandinskyV22PriorEmb2EmbPipeline",
|
"KandinskyV22PriorEmb2EmbPipeline",
|
||||||
"KandinskyV22PriorPipeline",
|
"KandinskyV22PriorPipeline",
|
||||||
"KolorsImg2ImgPipeline",
|
|
||||||
"KolorsPipeline",
|
|
||||||
"LatentConsistencyModelImg2ImgPipeline",
|
"LatentConsistencyModelImg2ImgPipeline",
|
||||||
"LatentConsistencyModelPipeline",
|
"LatentConsistencyModelPipeline",
|
||||||
"LattePipeline",
|
"LattePipeline",
|
||||||
@@ -397,7 +395,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
]
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPipeline"])
|
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||||
@@ -820,7 +818,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
|
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
|
||||||
else:
|
else:
|
||||||
from .pipelines import KolorsImg2ImgPipeline, KolorsPipeline
|
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
|
||||||
try:
|
try:
|
||||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ else:
|
|||||||
_import_structure["pag"].extend(
|
_import_structure["pag"].extend(
|
||||||
[
|
[
|
||||||
"AnimateDiffPAGPipeline",
|
"AnimateDiffPAGPipeline",
|
||||||
|
"KolorsPAGPipeline",
|
||||||
"HunyuanDiTPAGPipeline",
|
"HunyuanDiTPAGPipeline",
|
||||||
"StableDiffusion3PAGPipeline",
|
"StableDiffusion3PAGPipeline",
|
||||||
"StableDiffusionPAGPipeline",
|
"StableDiffusionPAGPipeline",
|
||||||
@@ -540,6 +541,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .pag import (
|
from .pag import (
|
||||||
AnimateDiffPAGPipeline,
|
AnimateDiffPAGPipeline,
|
||||||
HunyuanDiTPAGPipeline,
|
HunyuanDiTPAGPipeline,
|
||||||
|
KolorsPAGPipeline,
|
||||||
PixArtSigmaPAGPipeline,
|
PixArtSigmaPAGPipeline,
|
||||||
StableDiffusion3PAGPipeline,
|
StableDiffusion3PAGPipeline,
|
||||||
StableDiffusionControlNetPAGPipeline,
|
StableDiffusionControlNetPAGPipeline,
|
||||||
|
|||||||
@@ -162,8 +162,10 @@ _AUTO_INPAINT_DECODER_PIPELINES_MAPPING = OrderedDict(
|
|||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
from .kolors import KolorsPipeline
|
from .kolors import KolorsPipeline
|
||||||
|
from .pag import KolorsPAGPipeline
|
||||||
|
|
||||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||||
|
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors-pag"] = KolorsPAGPipeline
|
||||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||||
|
|
||||||
SUPPORTED_TASKS_MAPPINGS = [
|
SUPPORTED_TASKS_MAPPINGS = [
|
||||||
|
|||||||
@@ -143,10 +143,18 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|||||||
def unk_token(self) -> str:
|
def unk_token(self) -> str:
|
||||||
return "<unk>"
|
return "<unk>"
|
||||||
|
|
||||||
|
@unk_token.setter
|
||||||
|
def unk_token(self, value: str):
|
||||||
|
self._unk_token = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pad_token(self) -> str:
|
def pad_token(self) -> str:
|
||||||
return "<unk>"
|
return "<unk>"
|
||||||
|
|
||||||
|
@pad_token.setter
|
||||||
|
def pad_token(self, value: str):
|
||||||
|
self._pad_token = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pad_token_id(self):
|
def pad_token_id(self):
|
||||||
return self.get_command("<pad>")
|
return self.get_command("<pad>")
|
||||||
@@ -155,6 +163,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|||||||
def eos_token(self) -> str:
|
def eos_token(self) -> str:
|
||||||
return "</s>"
|
return "</s>"
|
||||||
|
|
||||||
|
@eos_token.setter
|
||||||
|
def eos_token(self, value: str):
|
||||||
|
self._eos_token = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eos_token_id(self):
|
def eos_token_id(self):
|
||||||
return self.get_command("<eos>")
|
return self.get_command("<eos>")
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ else:
|
|||||||
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
|
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
|
||||||
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
|
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
|
||||||
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
|
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
|
||||||
|
_import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
|
||||||
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
|
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
|
||||||
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
|
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
|
||||||
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
|
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
|
||||||
@@ -44,6 +45,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
|
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
|
||||||
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
|
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
|
||||||
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
|
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
|
||||||
|
from .pipeline_pag_kolors import KolorsPAGPipeline
|
||||||
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
|
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
|
||||||
from .pipeline_pag_sd import StableDiffusionPAGPipeline
|
from .pipeline_pag_sd import StableDiffusionPAGPipeline
|
||||||
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
|
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -17,6 +17,21 @@ class KolorsImg2ImgPipeline(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||||
|
|
||||||
|
|
||||||
|
class KolorsPAGPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers", "sentencepiece"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers", "sentencepiece"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||||
|
|
||||||
|
|
||||||
class KolorsPipeline(metaclass=DummyObject):
|
class KolorsPipeline(metaclass=DummyObject):
|
||||||
_backends = ["torch", "transformers", "sentencepiece"]
|
_backends = ["torch", "transformers", "sentencepiece"]
|
||||||
|
|
||||||
|
|||||||
@@ -133,23 +133,11 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
self.assertLessEqual(max_diff, 1e-3)
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
|
|
||||||
# throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter
|
|
||||||
# not sure if it is worth to fix it before integrating it to transformers
|
|
||||||
def test_save_load_optional_components(self):
|
def test_save_load_optional_components(self):
|
||||||
# TODO (Alvaro) need to fix later
|
super().test_save_load_optional_components(expected_max_difference=2e-4)
|
||||||
pass
|
|
||||||
|
|
||||||
# throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter
|
|
||||||
# not sure if it is worth to fix it before integrating it to transformers
|
|
||||||
def test_save_load_float16(self):
|
def test_save_load_float16(self):
|
||||||
# TODO (Alvaro) need to fix later
|
super().test_save_load_float16(expected_max_diff=2e-1)
|
||||||
pass
|
|
||||||
|
|
||||||
# throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter
|
|
||||||
# not sure if it is worth to fix it before integrating it to transformers
|
|
||||||
def test_save_load_local(self):
|
|
||||||
# TODO (Alvaro) need to fix later
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_inference_batch_single_identical(self):
|
def test_inference_batch_single_identical(self):
|
||||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=5e-4)
|
self._test_inference_batch_single_identical(expected_max_diff=5e-4)
|
||||||
|
|||||||
@@ -0,0 +1,152 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers import (
|
||||||
|
AutoencoderKL,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
|
KolorsImg2ImgPipeline,
|
||||||
|
UNet2DConditionModel,
|
||||||
|
)
|
||||||
|
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
|
||||||
|
from diffusers.utils.testing_utils import (
|
||||||
|
enable_full_determinism,
|
||||||
|
floats_tensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..pipeline_params import (
|
||||||
|
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||||
|
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||||
|
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||||
|
TEXT_TO_IMAGE_PARAMS,
|
||||||
|
)
|
||||||
|
from ..test_pipelines_common import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
enable_full_determinism()
|
||||||
|
|
||||||
|
|
||||||
|
class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
pipeline_class = KolorsImg2ImgPipeline
|
||||||
|
params = TEXT_TO_IMAGE_PARAMS
|
||||||
|
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||||
|
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
|
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
|
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
|
||||||
|
|
||||||
|
# Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components
|
||||||
|
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
unet = UNet2DConditionModel(
|
||||||
|
block_out_channels=(2, 4),
|
||||||
|
layers_per_block=2,
|
||||||
|
time_cond_proj_dim=time_cond_proj_dim,
|
||||||
|
sample_size=32,
|
||||||
|
in_channels=4,
|
||||||
|
out_channels=4,
|
||||||
|
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||||
|
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||||
|
# specific config below
|
||||||
|
attention_head_dim=(2, 4),
|
||||||
|
use_linear_projection=True,
|
||||||
|
addition_embed_type="text_time",
|
||||||
|
addition_time_embed_dim=8,
|
||||||
|
transformer_layers_per_block=(1, 2),
|
||||||
|
projection_class_embeddings_input_dim=56,
|
||||||
|
cross_attention_dim=8,
|
||||||
|
norm_num_groups=1,
|
||||||
|
)
|
||||||
|
scheduler = EulerDiscreteScheduler(
|
||||||
|
beta_start=0.00085,
|
||||||
|
beta_end=0.012,
|
||||||
|
steps_offset=1,
|
||||||
|
beta_schedule="scaled_linear",
|
||||||
|
timestep_spacing="leading",
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
vae = AutoencoderKL(
|
||||||
|
block_out_channels=[32, 64],
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=3,
|
||||||
|
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||||
|
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||||
|
latent_channels=4,
|
||||||
|
sample_size=128,
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||||
|
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||||
|
|
||||||
|
components = {
|
||||||
|
"unet": unet,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"vae": vae,
|
||||||
|
"text_encoder": text_encoder,
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
"image_encoder": None,
|
||||||
|
"feature_extractor": None,
|
||||||
|
}
|
||||||
|
return components
|
||||||
|
|
||||||
|
def get_dummy_inputs(self, device, seed=0):
|
||||||
|
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||||
|
image = image / 2 + 0.5
|
||||||
|
|
||||||
|
if str(device).startswith("mps"):
|
||||||
|
generator = torch.manual_seed(seed)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
"prompt": "A painting of a squirrel eating a burger",
|
||||||
|
"image": image,
|
||||||
|
"generator": generator,
|
||||||
|
"num_inference_steps": 2,
|
||||||
|
"guidance_scale": 5.0,
|
||||||
|
"output_type": "np",
|
||||||
|
"strength": 0.8,
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def test_inference(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
image = pipe(**inputs).images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
self.assertEqual(image.shape, (1, 64, 64, 3))
|
||||||
|
expected_slice = np.array(
|
||||||
|
[0.54823864, 0.43654007, 0.4886489, 0.63072854, 0.53641886, 0.4896852, 0.62123513, 0.5621531, 0.42809626]
|
||||||
|
)
|
||||||
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
|
|
||||||
|
def test_inference_batch_single_identical(self):
|
||||||
|
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
|
||||||
|
|
||||||
|
def test_float16_inference(self):
|
||||||
|
super().test_float16_inference(expected_max_diff=7e-2)
|
||||||
@@ -0,0 +1,252 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers import (
|
||||||
|
AutoencoderKL,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
|
KolorsPAGPipeline,
|
||||||
|
KolorsPipeline,
|
||||||
|
UNet2DConditionModel,
|
||||||
|
)
|
||||||
|
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
|
||||||
|
from diffusers.utils.testing_utils import enable_full_determinism
|
||||||
|
|
||||||
|
from ..pipeline_params import (
|
||||||
|
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||||
|
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
|
||||||
|
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||||
|
TEXT_TO_IMAGE_PARAMS,
|
||||||
|
)
|
||||||
|
from ..test_pipelines_common import (
|
||||||
|
PipelineFromPipeTesterMixin,
|
||||||
|
PipelineTesterMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
enable_full_determinism()
|
||||||
|
|
||||||
|
|
||||||
|
class KolorsPAGPipelineFastTests(
|
||||||
|
PipelineTesterMixin,
|
||||||
|
PipelineFromPipeTesterMixin,
|
||||||
|
unittest.TestCase,
|
||||||
|
):
|
||||||
|
pipeline_class = KolorsPAGPipeline
|
||||||
|
params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
|
||||||
|
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||||
|
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
|
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||||
|
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
|
||||||
|
|
||||||
|
# Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components
|
||||||
|
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
unet = UNet2DConditionModel(
|
||||||
|
block_out_channels=(2, 4),
|
||||||
|
layers_per_block=2,
|
||||||
|
time_cond_proj_dim=time_cond_proj_dim,
|
||||||
|
sample_size=32,
|
||||||
|
in_channels=4,
|
||||||
|
out_channels=4,
|
||||||
|
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||||
|
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||||
|
# specific config below
|
||||||
|
attention_head_dim=(2, 4),
|
||||||
|
use_linear_projection=True,
|
||||||
|
addition_embed_type="text_time",
|
||||||
|
addition_time_embed_dim=8,
|
||||||
|
transformer_layers_per_block=(1, 2),
|
||||||
|
projection_class_embeddings_input_dim=56,
|
||||||
|
cross_attention_dim=8,
|
||||||
|
norm_num_groups=1,
|
||||||
|
)
|
||||||
|
scheduler = EulerDiscreteScheduler(
|
||||||
|
beta_start=0.00085,
|
||||||
|
beta_end=0.012,
|
||||||
|
steps_offset=1,
|
||||||
|
beta_schedule="scaled_linear",
|
||||||
|
timestep_spacing="leading",
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
vae = AutoencoderKL(
|
||||||
|
block_out_channels=[32, 64],
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=3,
|
||||||
|
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||||
|
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||||
|
latent_channels=4,
|
||||||
|
sample_size=128,
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||||
|
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||||
|
|
||||||
|
components = {
|
||||||
|
"unet": unet,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"vae": vae,
|
||||||
|
"text_encoder": text_encoder,
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
"image_encoder": None,
|
||||||
|
"feature_extractor": None,
|
||||||
|
}
|
||||||
|
return components
|
||||||
|
|
||||||
|
def get_dummy_inputs(self, device, seed=0):
|
||||||
|
if str(device).startswith("mps"):
|
||||||
|
generator = torch.manual_seed(seed)
|
||||||
|
else:
|
||||||
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
inputs = {
|
||||||
|
"prompt": "A painting of a squirrel eating a burger",
|
||||||
|
"generator": generator,
|
||||||
|
"num_inference_steps": 2,
|
||||||
|
"guidance_scale": 5.0,
|
||||||
|
"pag_scale": 0.9,
|
||||||
|
"output_type": "np",
|
||||||
|
}
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def test_pag_disable_enable(self):
|
||||||
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
|
||||||
|
# base pipeline (expect same output when pag is disabled)
|
||||||
|
pipe_sd = KolorsPipeline(**components)
|
||||||
|
pipe_sd = pipe_sd.to(device)
|
||||||
|
pipe_sd.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
del inputs["pag_scale"]
|
||||||
|
assert (
|
||||||
|
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
|
||||||
|
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
|
||||||
|
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
# pag disabled with pag_scale=0.0
|
||||||
|
pipe_pag = self.pipeline_class(**components)
|
||||||
|
pipe_pag = pipe_pag.to(device)
|
||||||
|
pipe_pag.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
inputs["pag_scale"] = 0.0
|
||||||
|
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
# pag enabled
|
||||||
|
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||||
|
pipe_pag = pipe_pag.to(device)
|
||||||
|
pipe_pag.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
|
||||||
|
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
|
||||||
|
|
||||||
|
def test_pag_applied_layers(self):
|
||||||
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
|
||||||
|
# base pipeline
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe = pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
# pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers
|
||||||
|
all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k]
|
||||||
|
original_attn_procs = pipe.unet.attn_processors
|
||||||
|
pag_layers = ["mid", "down", "up"]
|
||||||
|
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||||
|
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
|
||||||
|
|
||||||
|
all_self_attn_mid_layers = [
|
||||||
|
"mid_block.attentions.0.transformer_blocks.0.attn1.processor",
|
||||||
|
"mid_block.attentions.0.transformer_blocks.1.attn1.processor",
|
||||||
|
]
|
||||||
|
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||||
|
pag_layers = ["mid"]
|
||||||
|
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||||
|
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
|
||||||
|
|
||||||
|
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||||
|
pag_layers = ["mid_block"]
|
||||||
|
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||||
|
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
|
||||||
|
|
||||||
|
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||||
|
pag_layers = ["mid_block.attentions.0"]
|
||||||
|
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||||
|
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
|
||||||
|
|
||||||
|
# pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model
|
||||||
|
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||||
|
pag_layers = ["mid_block.attentions.1"]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||||
|
|
||||||
|
# pag_applied_layers = "down" should apply to all self-attention layers in down_blocks
|
||||||
|
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||||
|
pag_layers = ["down"]
|
||||||
|
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||||
|
assert len(pipe.pag_attn_processors) == 4
|
||||||
|
|
||||||
|
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||||
|
pag_layers = ["down_blocks.0"]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||||
|
|
||||||
|
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||||
|
pag_layers = ["down_blocks.1"]
|
||||||
|
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||||
|
assert len(pipe.pag_attn_processors) == 4
|
||||||
|
|
||||||
|
pipe.unet.set_attn_processor(original_attn_procs.copy())
|
||||||
|
pag_layers = ["down_blocks.1.attentions.1"]
|
||||||
|
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
|
||||||
|
assert len(pipe.pag_attn_processors) == 2
|
||||||
|
|
||||||
|
def test_pag_inference(self):
|
||||||
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
|
||||||
|
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
|
||||||
|
pipe_pag = pipe_pag.to(device)
|
||||||
|
pipe_pag.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
image = pipe_pag(**inputs).images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
assert image.shape == (
|
||||||
|
1,
|
||||||
|
64,
|
||||||
|
64,
|
||||||
|
3,
|
||||||
|
), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
|
||||||
|
expected_slice = np.array(
|
||||||
|
[0.26030684, 0.43192005, 0.4042826, 0.4189067, 0.5181305, 0.3832534, 0.472135, 0.4145031, 0.43726248]
|
||||||
|
)
|
||||||
|
|
||||||
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
|
|
||||||
|
def test_inference_batch_single_identical(self):
|
||||||
|
self._test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||||
@@ -26,6 +26,7 @@ from diffusers import (
|
|||||||
ConsistencyDecoderVAE,
|
ConsistencyDecoderVAE,
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DiffusionPipeline,
|
DiffusionPipeline,
|
||||||
|
KolorsPipeline,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionXLPipeline,
|
StableDiffusionXLPipeline,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
@@ -656,6 +657,8 @@ class PipelineFromPipeTesterMixin:
|
|||||||
def original_pipeline_class(self):
|
def original_pipeline_class(self):
|
||||||
if "xl" in self.pipeline_class.__name__.lower():
|
if "xl" in self.pipeline_class.__name__.lower():
|
||||||
original_pipeline_class = StableDiffusionXLPipeline
|
original_pipeline_class = StableDiffusionXLPipeline
|
||||||
|
elif "kolors" in self.pipeline_class.__name__.lower():
|
||||||
|
original_pipeline_class = KolorsPipeline
|
||||||
else:
|
else:
|
||||||
original_pipeline_class = StableDiffusionPipeline
|
original_pipeline_class = StableDiffusionPipeline
|
||||||
|
|
||||||
@@ -681,6 +684,9 @@ class PipelineFromPipeTesterMixin:
|
|||||||
elif self.original_pipeline_class == StableDiffusionXLPipeline:
|
elif self.original_pipeline_class == StableDiffusionXLPipeline:
|
||||||
original_repo = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
original_repo = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||||
original_kwargs = {"requires_aesthetics_score": True, "force_zeros_for_empty_prompt": False}
|
original_kwargs = {"requires_aesthetics_score": True, "force_zeros_for_empty_prompt": False}
|
||||||
|
elif self.original_pipeline_class == KolorsPipeline:
|
||||||
|
original_repo = "hf-internal-testing/tiny-kolors-pipe"
|
||||||
|
original_kwargs = {"force_zeros_for_empty_prompt": False}
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"original_pipeline_class must be either StableDiffusionPipeline or StableDiffusionXLPipeline"
|
"original_pipeline_class must be either StableDiffusionPipeline or StableDiffusionXLPipeline"
|
||||||
|
|||||||
Reference in New Issue
Block a user