Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 496bf0be1b | |||
| 1db63655e4 | |||
| 4bf07e5fc5 | |||
| 571cea6dcf | |||
| 1f0570dba0 | |||
| a176cfde84 | |||
| 3aabef5de4 | |||
| 39be374591 | |||
| 54e17f3084 | |||
| 80702d222d | |||
| 625cc8ede8 | |||
| a2a9e4eadb | |||
| 0998bd75ad | |||
| 5f560d05a2 | |||
| 4b7a9e9fa9 | |||
| d8fa2de36f | |||
| 4df2739a5e | |||
| d92855ddf0 | |||
| 0a5c90ed47 | |||
| 0fa58127f8 | |||
| b165cf3742 |
@@ -0,0 +1,141 @@
|
||||
name: Fast PR tests for Modular
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- "src/diffusers/modular_pipelines/**.py"
|
||||
- "src/diffusers/models/modeling_utils.py"
|
||||
- "src/diffusers/models/model_loading_utils.py"
|
||||
- "src/diffusers/pipelines/pipeline_utils.py"
|
||||
- "src/diffusers/pipeline_loading_utils.py"
|
||||
- "src/diffusers/loaders/lora_base.py"
|
||||
- "src/diffusers/loaders/lora_pipeline.py"
|
||||
- "src/diffusers/loaders/peft.py"
|
||||
- "tests/modular_pipelines/**.py"
|
||||
- ".github/**.yml"
|
||||
- "utils/**.py"
|
||||
- "setup.py"
|
||||
push:
|
||||
branches:
|
||||
- ci-*
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_support_list.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch Modular Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: aws-highmemory-32-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_modular_pipelines
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on:
|
||||
group: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/modular_pipelines
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
@@ -45,8 +45,6 @@ from .modular_pipeline_utils import (
|
||||
OutputParam,
|
||||
format_components,
|
||||
format_configs,
|
||||
format_inputs_short,
|
||||
format_intermediates_short,
|
||||
make_doc_string,
|
||||
)
|
||||
|
||||
@@ -76,139 +74,59 @@ class PipelineState:
|
||||
[`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks.
|
||||
"""
|
||||
|
||||
inputs: Dict[str, Any] = field(default_factory=dict)
|
||||
intermediates: Dict[str, Any] = field(default_factory=dict)
|
||||
input_kwargs: Dict[str, List[str]] = field(default_factory=dict)
|
||||
intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict)
|
||||
values: Dict[str, Any] = field(default_factory=dict)
|
||||
kwargs_mapping: Dict[str, List[str]] = field(default_factory=dict)
|
||||
|
||||
def set_input(self, key: str, value: Any, kwargs_type: str = None):
|
||||
def set(self, key: str, value: Any, kwargs_type: str = None):
|
||||
"""
|
||||
Add an input to the immutable pipeline state, i.e, pipeline_state.inputs.
|
||||
|
||||
The kwargs_type parameter allows you to associate inputs with specific input types. For example, if you call
|
||||
set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), this input will be automatically fetched when a
|
||||
pipeline block has "guider_kwargs" in its expected_inputs list.
|
||||
Add a value to the pipeline state.
|
||||
|
||||
Args:
|
||||
key (str): The key for the input
|
||||
value (Any): The input value
|
||||
kwargs_type (str): The kwargs_type with which the input is associated
|
||||
key (str): The key for the value
|
||||
value (Any): The value to store
|
||||
kwargs_type (str): The kwargs_type with which the value is associated
|
||||
"""
|
||||
self.inputs[key] = value
|
||||
self.values[key] = value
|
||||
|
||||
if kwargs_type is not None:
|
||||
if kwargs_type not in self.input_kwargs:
|
||||
self.input_kwargs[kwargs_type] = [key]
|
||||
if kwargs_type not in self.kwargs_mapping:
|
||||
self.kwargs_mapping[kwargs_type] = [key]
|
||||
else:
|
||||
self.input_kwargs[kwargs_type].append(key)
|
||||
self.kwargs_mapping[kwargs_type].append(key)
|
||||
|
||||
def set_intermediate(self, key: str, value: Any, kwargs_type: str = None):
|
||||
def get(self, keys: Union[str, List[str]], default: Any = None) -> Union[Any, Dict[str, Any]]:
|
||||
"""
|
||||
Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates.
|
||||
|
||||
The kwargs_type parameter allows you to associate intermediate values with specific input types. For example,
|
||||
if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), this intermediate value will be
|
||||
automatically fetched when a pipeline block has "latents_kwargs" in its expected_intermediate_inputs list.
|
||||
Get one or multiple values from the pipeline state.
|
||||
|
||||
Args:
|
||||
key (str): The key for the intermediate value
|
||||
value (Any): The intermediate value
|
||||
kwargs_type (str): The kwargs_type with which the intermediate value is associated
|
||||
"""
|
||||
self.intermediates[key] = value
|
||||
if kwargs_type is not None:
|
||||
if kwargs_type not in self.intermediate_kwargs:
|
||||
self.intermediate_kwargs[kwargs_type] = [key]
|
||||
else:
|
||||
self.intermediate_kwargs[kwargs_type].append(key)
|
||||
|
||||
def get_input(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get an input from the pipeline state.
|
||||
|
||||
Args:
|
||||
key (str): The key for the input
|
||||
default (Any): The default value to return if the input is not found
|
||||
keys (Union[str, List[str]]): Key or list of keys for the values
|
||||
default (Any): The default value to return if not found
|
||||
|
||||
Returns:
|
||||
Any: The input value
|
||||
Union[Any, Dict[str, Any]]: Single value if keys is str, dictionary of values if keys is list
|
||||
"""
|
||||
value = self.inputs.get(key, default)
|
||||
if value is not None:
|
||||
return deepcopy(value)
|
||||
if isinstance(keys, str):
|
||||
return self.values.get(keys, default)
|
||||
return {key: self.values.get(key, default) for key in keys}
|
||||
|
||||
def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
|
||||
def get_by_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get multiple inputs from the pipeline state.
|
||||
|
||||
Args:
|
||||
keys (List[str]): The keys for the inputs
|
||||
default (Any): The default value to return if the input is not found
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary of inputs with matching keys
|
||||
"""
|
||||
return {key: self.inputs.get(key, default) for key in keys}
|
||||
|
||||
def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all inputs with matching kwargs_type.
|
||||
Get all values with matching kwargs_type.
|
||||
|
||||
Args:
|
||||
kwargs_type (str): The kwargs_type to filter by
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary of inputs with matching kwargs_type
|
||||
Dict[str, Any]: Dictionary of values with matching kwargs_type
|
||||
"""
|
||||
input_names = self.input_kwargs.get(kwargs_type, [])
|
||||
return self.get_inputs(input_names)
|
||||
|
||||
def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all intermediates with matching kwargs_type.
|
||||
|
||||
Args:
|
||||
kwargs_type (str): The kwargs_type to filter by
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary of intermediates with matching kwargs_type
|
||||
"""
|
||||
intermediate_names = self.intermediate_kwargs.get(kwargs_type, [])
|
||||
return self.get_intermediates(intermediate_names)
|
||||
|
||||
def get_intermediate(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get an intermediate value from the pipeline state.
|
||||
|
||||
Args:
|
||||
key (str): The key for the intermediate value
|
||||
default (Any): The default value to return if the intermediate value is not found
|
||||
|
||||
Returns:
|
||||
Any: The intermediate value
|
||||
"""
|
||||
return self.intermediates.get(key, default)
|
||||
|
||||
def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get multiple intermediate values from the pipeline state.
|
||||
|
||||
Args:
|
||||
keys (List[str]): The keys for the intermediate values
|
||||
default (Any): The default value to return if the intermediate value is not found
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary of intermediate values with matching keys
|
||||
"""
|
||||
return {key: self.intermediates.get(key, default) for key in keys}
|
||||
value_names = self.kwargs_mapping.get(kwargs_type, [])
|
||||
return self.get(value_names)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert PipelineState to a dictionary.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing all attributes of the PipelineState
|
||||
"""
|
||||
return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates}
|
||||
return {**self.__dict__}
|
||||
|
||||
def __repr__(self):
|
||||
def format_value(v):
|
||||
@@ -219,21 +137,10 @@ class PipelineState:
|
||||
else:
|
||||
return repr(v)
|
||||
|
||||
inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items())
|
||||
intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items())
|
||||
values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items())
|
||||
kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items())
|
||||
|
||||
# Format input_kwargs and intermediate_kwargs
|
||||
input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items())
|
||||
intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items())
|
||||
|
||||
return (
|
||||
f"PipelineState(\n"
|
||||
f" inputs={{\n{inputs}\n }},\n"
|
||||
f" intermediates={{\n{intermediates}\n }},\n"
|
||||
f" input_kwargs={{\n{input_kwargs_str}\n }},\n"
|
||||
f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n"
|
||||
f")"
|
||||
)
|
||||
return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -322,7 +229,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
config_name = "modular_config.json"
|
||||
model_name = None
|
||||
|
||||
@classmethod
|
||||
@@ -334,6 +241,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
def __init__(self):
|
||||
self.sub_blocks = InsertableDict()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Description of the block. Must be implemented by subclasses."""
|
||||
return ""
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
@@ -343,8 +258,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
def inputs(self) -> List[InputParam]:
|
||||
"""List of input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
@@ -352,6 +267,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
def _get_outputs(self):
|
||||
return self.intermediate_outputs
|
||||
|
||||
@property
|
||||
def outputs(self) -> List[OutputParam]:
|
||||
return self._get_outputs()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -436,12 +358,12 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
def get_block_state(self, state: PipelineState) -> dict:
|
||||
"""Get all inputs and intermediates in one dictionary"""
|
||||
data = {}
|
||||
state_inputs = self.inputs + self.intermediate_inputs
|
||||
state_inputs = self.inputs
|
||||
|
||||
# Check inputs
|
||||
for input_param in state_inputs:
|
||||
if input_param.name:
|
||||
value = state.get_input(input_param.name) or state.get_intermediate(input_param.name)
|
||||
value = state.get(input_param.name)
|
||||
if input_param.required and value is None:
|
||||
raise ValueError(f"Required input '{input_param.name}' is missing")
|
||||
elif value is not None or (value is None and input_param.name not in data):
|
||||
@@ -451,9 +373,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
# if kwargs_type is provided, get all inputs with matching kwargs_type
|
||||
if input_param.kwargs_type not in data:
|
||||
data[input_param.kwargs_type] = {}
|
||||
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) or state.get_intermediate_kwargs(
|
||||
input_param.kwargs_type
|
||||
)
|
||||
inputs_kwargs = state.get_by_kwargs(input_param.kwargs_type)
|
||||
if inputs_kwargs:
|
||||
for k, v in inputs_kwargs.items():
|
||||
if v is not None:
|
||||
@@ -467,25 +387,30 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
if not hasattr(block_state, output_param.name):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.set_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
state.set(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediate_inputs:
|
||||
for input_param in self.inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
current_value = state.get(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
state.set(input_param.name, param, input_param.kwargs_type)
|
||||
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
if param_name is None:
|
||||
continue
|
||||
|
||||
if not hasattr(block_state, param_name):
|
||||
continue
|
||||
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(param_name, param, input_param.kwargs_type)
|
||||
state.set(param_name, param, input_param.kwargs_type)
|
||||
|
||||
@staticmethod
|
||||
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
|
||||
@@ -553,199 +478,17 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
|
||||
class PipelineBlock(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Block is the basic building block of a Modular Pipeline.
|
||||
|
||||
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipeline blocks (such as loading or saving etc.)
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental feature and is likely to change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
description (str, optional): A description of the block, defaults to None. Define as a property in subclasses.
|
||||
expected_components (List[ComponentSpec], optional):
|
||||
A list of components that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
expected_configs (List[ConfigSpec], optional):
|
||||
A list of configs that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
inputs (List[InputParam], optional):
|
||||
A list of inputs that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
intermediate_inputs (List[InputParam], optional):
|
||||
A list of intermediate inputs that are expected to be used in the block, defaults to []. To override,
|
||||
define as a property in subclasses.
|
||||
intermediate_outputs (List[OutputParam], optional):
|
||||
A list of intermediate outputs that are expected to be used in the block, defaults to []. To override,
|
||||
define as a property in subclasses.
|
||||
outputs (List[OutputParam], optional):
|
||||
A list of outputs that are expected to be used in the block, defaults to []. To override, define as a
|
||||
property in subclasses.
|
||||
required_inputs (List[str], optional):
|
||||
A list of required inputs that are expected to be used in the block, defaults to []. To override, define as
|
||||
a property in subclasses.
|
||||
required_intermediate_inputs (List[str], optional):
|
||||
A list of required intermediate inputs that are expected to be used in the block, defaults to []. To
|
||||
override, define as a property in subclasses.
|
||||
required_intermediate_outputs (List[str], optional):
|
||||
A list of required intermediate outputs that are expected to be used in the block, defaults to []. To
|
||||
override, define as a property in subclasses.
|
||||
"""
|
||||
|
||||
model_name = None
|
||||
|
||||
def __init__(self):
|
||||
self.sub_blocks = InsertableDict()
|
||||
@property
|
||||
def input_names(self) -> List[str]:
|
||||
return [input_param.name for input_param in self.inputs]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Description of the block. Must be implemented by subclasses."""
|
||||
# raise NotImplementedError("description method must be implemented in subclasses")
|
||||
return "TODO: add a description"
|
||||
def intermediate_output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.intermediate_outputs]
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
"""List of input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
def _get_outputs(self):
|
||||
return self.intermediate_outputs
|
||||
|
||||
# YiYi TODO: is it too easy for user to unintentionally override these properties?
|
||||
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
|
||||
@property
|
||||
def outputs(self) -> List[OutputParam]:
|
||||
return self._get_outputs()
|
||||
|
||||
def _get_required_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
return self._get_required_inputs()
|
||||
|
||||
def _get_required_intermediate_inputs(self):
|
||||
input_names = []
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
return self._get_required_intermediate_inputs()
|
||||
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
raise NotImplementedError("__call__ method must be implemented in subclasses")
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
base_class = self.__class__.__bases__[0].__name__
|
||||
|
||||
# Format description with proper indentation
|
||||
desc_lines = self.description.split("\n")
|
||||
desc = []
|
||||
# First line with "Description:" label
|
||||
desc.append(f" Description: {desc_lines[0]}")
|
||||
# Subsequent lines with proper indentation
|
||||
if len(desc_lines) > 1:
|
||||
desc.extend(f" {line}" for line in desc_lines[1:])
|
||||
desc = "\n".join(desc) + "\n"
|
||||
|
||||
# Components section - use format_components with add_empty_lines=False
|
||||
expected_components = getattr(self, "expected_components", [])
|
||||
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
|
||||
components = " " + components_str.replace("\n", "\n ")
|
||||
|
||||
# Configs section - use format_configs with add_empty_lines=False
|
||||
expected_configs = getattr(self, "expected_configs", [])
|
||||
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
|
||||
configs = " " + configs_str.replace("\n", "\n ")
|
||||
|
||||
# Inputs section
|
||||
inputs_str = format_inputs_short(self.inputs)
|
||||
inputs = "Inputs:\n " + inputs_str
|
||||
|
||||
# Intermediates section
|
||||
intermediates_str = format_intermediates_short(
|
||||
self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
|
||||
)
|
||||
intermediates = f"Intermediates:\n{intermediates_str}"
|
||||
|
||||
return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)"
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
expected_components=self.expected_components,
|
||||
expected_configs=self.expected_configs,
|
||||
)
|
||||
|
||||
def set_block_state(self, state: PipelineState, block_state: BlockState):
|
||||
for output_param in self.intermediate_outputs:
|
||||
if not hasattr(block_state, output_param.name):
|
||||
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
|
||||
param = getattr(block_state, output_param.name)
|
||||
state.set_intermediate(output_param.name, param, output_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediate_inputs:
|
||||
if hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.name and hasattr(block_state, input_param.name):
|
||||
param = getattr(block_state, input_param.name)
|
||||
# Only add if the value is different from what's in the state
|
||||
current_value = state.get_intermediate(input_param.name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
|
||||
elif input_param.kwargs_type:
|
||||
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
|
||||
# we need to first find out which inputs are and loop through them.
|
||||
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
|
||||
for param_name, current_value in intermediate_kwargs.items():
|
||||
param = getattr(block_state, param_name)
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set_intermediate(param_name, param, input_param.kwargs_type)
|
||||
def output_names(self) -> List[str]:
|
||||
return [output_param.name for output_param in self.outputs]
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
@@ -836,22 +579,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return list(required_by_all)
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
if None not in self.block_trigger_inputs:
|
||||
return []
|
||||
first_block = next(iter(self.sub_blocks.values()))
|
||||
required_by_all = set(getattr(first_block, "required_intermediate_inputs", set()))
|
||||
|
||||
# Intersect with required inputs from all other blocks
|
||||
for block in list(self.sub_blocks.values())[1:]:
|
||||
block_required = set(getattr(block, "required_intermediate_inputs", set()))
|
||||
required_by_all.intersection_update(block_required)
|
||||
|
||||
return list(required_by_all)
|
||||
|
||||
# YiYi TODO: add test for this
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
@@ -865,18 +592,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
input_param.required = False
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
named_inputs = [(name, block.intermediate_inputs) for name, block in self.sub_blocks.items()]
|
||||
combined_inputs = self.combine_inputs(*named_inputs)
|
||||
# mark Required inputs only if that input is required by all the blocks
|
||||
for input_param in combined_inputs:
|
||||
if input_param.name in self.required_intermediate_inputs:
|
||||
input_param.required = True
|
||||
else:
|
||||
input_param.required = False
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
|
||||
@@ -895,10 +610,10 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
block = self.trigger_to_block_map.get(None)
|
||||
for input_name in self.block_trigger_inputs:
|
||||
if input_name is not None and state.get_input(input_name) is not None:
|
||||
if input_name is not None and state.get(input_name) is not None:
|
||||
block = self.trigger_to_block_map[input_name]
|
||||
break
|
||||
elif input_name is not None and state.get_intermediate(input_name) is not None:
|
||||
elif input_name is not None and state.get(input_name) is not None:
|
||||
block = self.trigger_to_block_map[input_name]
|
||||
break
|
||||
|
||||
@@ -1117,6 +832,34 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
sub_blocks[block_name] = block_cls()
|
||||
self.sub_blocks = sub_blocks
|
||||
|
||||
def _get_inputs(self):
|
||||
inputs = []
|
||||
outputs = set()
|
||||
|
||||
# Go through all blocks in order
|
||||
for block in self.sub_blocks.values():
|
||||
# Add inputs that aren't in outputs yet
|
||||
for inp in block.inputs:
|
||||
if inp.name not in outputs and inp.name not in {input.name for input in inputs}:
|
||||
inputs.append(inp)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
should_add_outputs = False
|
||||
|
||||
if should_add_outputs:
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
|
||||
return inputs
|
||||
|
||||
# YiYi TODO: add test for this
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return self._get_inputs()
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
# Get the first block from the dictionary
|
||||
@@ -1130,65 +873,11 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return list(required_by_any)
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
required_intermediate_inputs = []
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
required_intermediate_inputs.append(input_param.name)
|
||||
return required_intermediate_inputs
|
||||
|
||||
# YiYi TODO: add test for this
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return self.get_inputs()
|
||||
|
||||
def get_inputs(self):
|
||||
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
|
||||
combined_inputs = self.combine_inputs(*named_inputs)
|
||||
# mark Required inputs only if that input is required any of the blocks
|
||||
for input_param in combined_inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
input_param.required = True
|
||||
else:
|
||||
input_param.required = False
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return self.get_intermediate_inputs()
|
||||
|
||||
def get_intermediate_inputs(self):
|
||||
inputs = []
|
||||
outputs = set()
|
||||
added_inputs = set()
|
||||
|
||||
# Go through all blocks in order
|
||||
for block in self.sub_blocks.values():
|
||||
# Add inputs that aren't in outputs yet
|
||||
for inp in block.intermediate_inputs:
|
||||
if inp.name not in outputs and inp.name not in added_inputs:
|
||||
inputs.append(inp)
|
||||
added_inputs.add(inp.name)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
should_add_outputs = False
|
||||
|
||||
if should_add_outputs:
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
named_outputs = []
|
||||
for name, block in self.sub_blocks.items():
|
||||
inp_names = {inp.name for inp in block.intermediate_inputs}
|
||||
inp_names = {inp.name for inp in block.inputs}
|
||||
# so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
|
||||
# filter out them here so they do not end up as intermediate_outputs
|
||||
if name not in inp_names:
|
||||
@@ -1406,7 +1095,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def doc(self):
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.intermediate_inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
class_name=self.__class__.__name__,
|
||||
@@ -1456,16 +1144,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""List of input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
"""List of intermediate input parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
@property
|
||||
def loop_required_inputs(self) -> List[str]:
|
||||
input_names = []
|
||||
@@ -1475,12 +1153,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
return input_names
|
||||
|
||||
@property
|
||||
def loop_required_intermediate_inputs(self) -> List[str]:
|
||||
input_names = []
|
||||
for input_param in self.loop_intermediate_inputs:
|
||||
if input_param.required:
|
||||
input_names.append(input_param.name)
|
||||
return input_names
|
||||
def loop_intermediate_outputs(self) -> List[OutputParam]:
|
||||
"""List of intermediate output parameters. Must be implemented by subclasses."""
|
||||
return []
|
||||
|
||||
# modified from SequentialPipelineBlocks to include loop_expected_components
|
||||
@property
|
||||
@@ -1508,43 +1183,16 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
expected_configs.append(config)
|
||||
return expected_configs
|
||||
|
||||
# modified from SequentialPipelineBlocks to include loop_inputs
|
||||
def get_inputs(self):
|
||||
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
|
||||
named_inputs.append(("loop", self.loop_inputs))
|
||||
combined_inputs = self.combine_inputs(*named_inputs)
|
||||
# mark Required inputs only if that input is required any of the blocks
|
||||
for input_param in combined_inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
input_param.required = True
|
||||
else:
|
||||
input_param.required = False
|
||||
return combined_inputs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
|
||||
def inputs(self):
|
||||
return self.get_inputs()
|
||||
|
||||
# modified from SequentialPipelineBlocks to include loop_intermediate_inputs
|
||||
@property
|
||||
def intermediate_inputs(self):
|
||||
intermediates = self.get_intermediate_inputs()
|
||||
intermediate_names = [input.name for input in intermediates]
|
||||
for loop_intermediate_input in self.loop_intermediate_inputs:
|
||||
if loop_intermediate_input.name not in intermediate_names:
|
||||
intermediates.append(loop_intermediate_input)
|
||||
return intermediates
|
||||
|
||||
# modified from SequentialPipelineBlocks
|
||||
def get_intermediate_inputs(self):
|
||||
def _get_inputs(self):
|
||||
inputs = []
|
||||
inputs.extend(self.loop_inputs)
|
||||
outputs = set()
|
||||
|
||||
# Go through all blocks in order
|
||||
for block in self.sub_blocks.values():
|
||||
for name, block in self.sub_blocks.items():
|
||||
# Add inputs that aren't in outputs yet
|
||||
inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs)
|
||||
for inp in block.inputs:
|
||||
if inp.name not in outputs and inp not in inputs:
|
||||
inputs.append(inp)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
@@ -1555,8 +1203,20 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
|
||||
for input_param in inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
input_param.required = True
|
||||
else:
|
||||
input_param.required = False
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
|
||||
def inputs(self):
|
||||
return self._get_inputs()
|
||||
|
||||
# modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
@@ -1574,19 +1234,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return list(required_by_any)
|
||||
|
||||
# YiYi TODO: maybe we do not need this, it is only used in docstring,
|
||||
# intermediate_inputs is by default required, unless you manually handle it inside the block
|
||||
@property
|
||||
def required_intermediate_inputs(self) -> List[str]:
|
||||
required_intermediate_inputs = []
|
||||
for input_param in self.intermediate_inputs:
|
||||
if input_param.required:
|
||||
required_intermediate_inputs.append(input_param.name)
|
||||
for input_param in self.loop_intermediate_inputs:
|
||||
if input_param.required:
|
||||
required_intermediate_inputs.append(input_param.name)
|
||||
return required_intermediate_inputs
|
||||
|
||||
# YiYi TODO: this need to be thought about more
|
||||
# modified from SequentialPipelineBlocks to include loop_intermediate_outputs
|
||||
@property
|
||||
@@ -1876,96 +1523,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
|
||||
"""
|
||||
Execute the pipeline by running the pipeline blocks with the given inputs.
|
||||
|
||||
Args:
|
||||
state (`PipelineState`, optional):
|
||||
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
|
||||
created based on the user inputs and the pipeline blocks's requirement.
|
||||
output (`str` or `List[str]`, optional):
|
||||
Optional specification of what to return:
|
||||
- None: Returns the complete `PipelineState` with all inputs and intermediates (default)
|
||||
- str: Returns a specific intermediate value from the state (e.g. `output="image"`)
|
||||
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
|
||||
"latents"]`)
|
||||
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Get complete pipeline state
|
||||
state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
|
||||
print(state.intermediates) # All intermediate outputs
|
||||
|
||||
# Get specific output
|
||||
image = pipeline(prompt="A beautiful sunset", output="image")
|
||||
|
||||
# Get multiple specific outputs
|
||||
results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
|
||||
image, latents = results["image"], results["latents"]
|
||||
|
||||
# Continue from previous state
|
||||
state = pipeline(prompt="A beautiful sunset")
|
||||
new_state = pipeline(state=state, output="image") # Continue processing
|
||||
```
|
||||
|
||||
Returns:
|
||||
- If `output` is None: Complete `PipelineState` containing all inputs and intermediates
|
||||
- If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
|
||||
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
|
||||
`output=["image", "latents"]`)
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
passed_kwargs = kwargs.copy()
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
if name not in intermediate_inputs:
|
||||
state.set_input(name, passed_kwargs.pop(name), kwargs_type)
|
||||
else:
|
||||
state.set_input(name, passed_kwargs[name], kwargs_type)
|
||||
elif name not in state.inputs:
|
||||
state.set_input(name, default, kwargs_type)
|
||||
|
||||
for expected_intermediate_param in self.blocks.intermediate_inputs:
|
||||
name = expected_intermediate_param.name
|
||||
kwargs_type = expected_intermediate_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
state.set_intermediate(name, passed_kwargs.pop(name), kwargs_type)
|
||||
|
||||
# Warn about unexpected inputs
|
||||
if len(passed_kwargs) > 0:
|
||||
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
_, state = self.blocks(self, state)
|
||||
except Exception:
|
||||
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
if output is None:
|
||||
return state
|
||||
|
||||
elif isinstance(output, str):
|
||||
return state.get_intermediate(output)
|
||||
|
||||
elif isinstance(output, (list, tuple)):
|
||||
return state.get_intermediates(output)
|
||||
else:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
def load_default_components(self, **kwargs):
|
||||
"""
|
||||
Load from_pretrained components using the loading specs in the config dict.
|
||||
@@ -2784,3 +2341,92 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
type_hint=type_hint,
|
||||
**spec_dict,
|
||||
)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
|
||||
if hasattr(sub_block, "set_progress_bar_config"):
|
||||
sub_block.set_progress_bar_config(**kwargs)
|
||||
|
||||
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
|
||||
"""
|
||||
Execute the pipeline by running the pipeline blocks with the given inputs.
|
||||
|
||||
Args:
|
||||
state (`PipelineState`, optional):
|
||||
PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
|
||||
created based on the user inputs and the pipeline blocks's requirement.
|
||||
output (`str` or `List[str]`, optional):
|
||||
Optional specification of what to return:
|
||||
- None: Returns the complete `PipelineState` with all inputs and intermediates (default)
|
||||
- str: Returns a specific intermediate value from the state (e.g. `output="image"`)
|
||||
- List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
|
||||
"latents"]`)
|
||||
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Get complete pipeline state
|
||||
state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
|
||||
print(state.intermediates) # All intermediate outputs
|
||||
|
||||
# Get specific output
|
||||
image = pipeline(prompt="A beautiful sunset", output="image")
|
||||
|
||||
# Get multiple specific outputs
|
||||
results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
|
||||
image, latents = results["image"], results["latents"]
|
||||
|
||||
# Continue from previous state
|
||||
state = pipeline(prompt="A beautiful sunset")
|
||||
new_state = pipeline(state=state, output="image") # Continue processing
|
||||
```
|
||||
|
||||
Returns:
|
||||
- If `output` is None: Complete `PipelineState` containing all inputs and intermediates
|
||||
- If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
|
||||
- If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
|
||||
`output=["image", "latents"]`)
|
||||
"""
|
||||
if state is None:
|
||||
state = PipelineState()
|
||||
|
||||
# Make a copy of the input kwargs
|
||||
passed_kwargs = kwargs.copy()
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
intermediate_inputs = [inp.name for inp in self.blocks.inputs]
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
if name in passed_kwargs:
|
||||
if name not in intermediate_inputs:
|
||||
state.set(name, passed_kwargs.pop(name), kwargs_type)
|
||||
else:
|
||||
state.set(name, passed_kwargs[name], kwargs_type)
|
||||
elif name not in state.values:
|
||||
state.set(name, default, kwargs_type)
|
||||
|
||||
# Warn about unexpected inputs
|
||||
if len(passed_kwargs) > 0:
|
||||
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
_, state = self.blocks(self, state)
|
||||
except Exception:
|
||||
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
if output is None:
|
||||
return state
|
||||
|
||||
if isinstance(output, str):
|
||||
return state.get(output)
|
||||
|
||||
elif isinstance(output, (list, tuple)):
|
||||
return state.get(output)
|
||||
else:
|
||||
raise ValueError(f"Output '{output}' is not a valid output type")
|
||||
|
||||
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
|
||||
return latents
|
||||
|
||||
|
||||
class StableDiffusionXLInputStep(PipelineBlock):
|
||||
class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
InputParam("denoising_start"),
|
||||
# YiYi TODO: do we need num_images_per_prompt here?
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
"`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
|
||||
"`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -744,8 +729,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=None,
|
||||
is_strength_max=True,
|
||||
add_noise=True,
|
||||
return_noise=False,
|
||||
return_image_latents=False,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
@@ -768,7 +751,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
if image.shape[1] == 4:
|
||||
image_latents = image.to(device=device, dtype=dtype)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
elif return_image_latents or (latents is None and not is_strength_max):
|
||||
elif latents is None and not is_strength_max:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(components, image=image, generator=generator)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
@@ -786,13 +769,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = image_latents.to(device)
|
||||
|
||||
outputs = (latents,)
|
||||
|
||||
if return_noise:
|
||||
outputs += (noise,)
|
||||
|
||||
if return_image_latents:
|
||||
outputs += (image_latents,)
|
||||
outputs = (latents, noise, image_latents)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -864,7 +841,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
|
||||
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
|
||||
|
||||
block_state.latents, block_state.noise = self.prepare_latents_inpaint(
|
||||
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
components.num_channels_latents,
|
||||
@@ -878,8 +855,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=block_state.latent_timestep,
|
||||
is_strength_max=block_state.is_strength_max,
|
||||
add_noise=block_state.add_noise,
|
||||
return_noise=True,
|
||||
return_image_latents=False,
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
@@ -900,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -920,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("denoising_start"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
@@ -981,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1002,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width"),
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -1092,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1129,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("aesthetic_score", default=6.0),
|
||||
InputParam("negative_aesthetic_score", default=2.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1316,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1345,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("negative_crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1499,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1527,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1718,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# YiYi experimenting composible denoise loop
|
||||
# loop step (1): prepare latent input for denoiser
|
||||
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (1): prepare latent input for denoiser (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance
|
||||
class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance (with controlnet)
|
||||
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"controlnet_cond",
|
||||
required=True,
|
||||
@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents
|
||||
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||
|
||||
@@ -57,7 +57,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
@@ -691,7 +686,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -726,11 +721,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("mask_image", required=True),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -247,10 +247,6 @@ SDXL_INPUTS_SCHEMA = {
|
||||
"control_mode": InputParam(
|
||||
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
@@ -271,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"preprocess_kwargs": InputParam(
|
||||
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
|
||||
),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
|
||||
),
|
||||
"latent_timestep": InputParam(
|
||||
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
|
||||
),
|
||||
|
||||
+488
@@ -0,0 +1,488 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import (
|
||||
ClassifierFreeGuidance,
|
||||
ModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
)
|
||||
from diffusers.loaders import ModularIPAdapterMixin
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_state_dict,
|
||||
)
|
||||
from ..test_modular_pipelines_common import (
|
||||
ModularPipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SDXLModularTests:
    """
    Shared base for the SDXL modular pipeline tests.

    Provides the pipeline factory, the baseline call inputs and a reusable
    slice-comparison check that the concrete test classes build on.
    """

    pipeline_class = StableDiffusionXLModularPipeline
    pipeline_blocks_class = StableDiffusionXLAutoBlocks
    repo = "hf-internal-testing/tiny-sdxl-modular"
    params = frozenset(
        (
            "prompt",
            "height",
            "width",
            "negative_prompt",
            "cross_attention_kwargs",
            "image",
            "mask_image",
        )
    )
    batch_params = frozenset(("prompt", "negative_prompt", "image", "mask_image"))

    def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
        """Build a pipeline from `pipeline_blocks_class` and `repo`, then load its default components."""
        blocks = self.pipeline_blocks_class()
        pipe = blocks.init_pipeline(self.repo, components_manager=components_manager)
        pipe.load_default_components(torch_dtype=torch_dtype)
        return pipe

    def get_dummy_inputs(self, device, seed=0):
        """Return the minimal call kwargs, including a generator seeded with `seed`."""
        # "mps" devices get the globally seeded generator instead of a device-bound one.
        on_mps = str(device).startswith("mps")
        generator = torch.manual_seed(seed) if on_mps else torch.Generator(device=device).manual_seed(seed)
        return {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "np",
        }

    def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
        """Run the pipeline on CPU and compare the bottom-right 3x3 patch of the last channel with `expected_slice`."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = self.get_pipeline().to(device)
        pipe.set_progress_bar_config(disable=None)

        call_kwargs = self.get_dummy_inputs(device)
        image = pipe(**call_kwargs, output="images")
        corner = image[0, -3:, -3:, -1]

        assert image.shape == expected_image_shape

        largest_gap = np.abs(corner.flatten() - expected_slice).max()
        assert largest_gap < expected_max_diff, "Image Slice does not match expected slice"
|
||||
|
||||
|
||||
class SDXLModularIPAdapterTests:
    """
    This mixin is designed to test IP Adapter.

    It relies on the host test class for `pipeline_class`, `pipeline_blocks_class`,
    `repo` and `get_dummy_inputs`.
    """

    def test_pipeline_inputs_and_blocks(self):
        """Check that the IP-Adapter inputs exist with the `ip_adapter` block and vanish once it is removed."""
        blocks = self.pipeline_blocks_class()
        parameters = blocks.input_names

        assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
        assert "ip_adapter_image" in parameters, (
            "`ip_adapter_image` argument must be supported by the `__call__` method"
        )
        assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"

        # Without the IPAdapter block, neither adapter input may remain exposed.
        _ = blocks.sub_blocks.pop("ip_adapter")
        parameters = blocks.input_names
        assert "ip_adapter_image" not in parameters, (
            "`ip_adapter_image` argument must be removed from the `__call__` method"
        )
        assert "ip_adapter_image_embeds" not in parameters, (
            "`ip_adapter_image_embeds` argument must be removed from the `__call__` method"
        )

    def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
        """Return random image embeddings of shape (1, 1, cross_attention_dim) on `torch_device`."""
        return torch.randn((1, 1, cross_attention_dim), device=torch_device)

    def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
        """Return random embeddings of shape (1, 1, 1, cross_attention_dim) on `torch_device`."""
        return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)

    def _get_dummy_masks(self, input_size: int = 64):
        """Return a (1, 1, input_size, input_size) mask whose left half of the last axis is 1 and the rest 0."""
        _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
        _masks[0, :, :, : int(input_size / 2)] = 1
        return _masks

    def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
        """Adjust `inputs` in place for the IP-Adapter test and return it."""
        blocks = self.pipeline_blocks_class()
        _ = blocks.sub_blocks.pop("ip_adapter")
        parameters = blocks.input_names
        # Blocks that take both `image` and `strength` get more denoising steps.
        if "image" in parameters and "strength" in parameters:
            inputs["num_inference_steps"] = 4

        inputs["output_type"] = "np"
        return inputs

    def _run_ip_adapter_pass(self, pipe, scale, num_adapters, cross_attention_dim, as_slice):
        """
        Run one forward pass with dummy IP-Adapter embeddings at the given adapter scale.

        Args:
            pipe: pipeline whose UNet already has IP-Adapter weights loaded.
            scale: value forwarded to `pipe.set_ip_adapter_scale`.
            num_adapters: number of entries in each embeds list.
            cross_attention_dim: last dimension of the dummy embeddings.
            as_slice: when True, return only the flattened bottom-right 3x3 patch of the last channel.

        Returns:
            The pipeline output, or its flattened corner slice when `as_slice` is True.
        """
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        # One random tensor per list, repeated `num_adapters` times.
        inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * num_adapters
        inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * num_adapters
        pipe.set_ip_adapter_scale(scale)
        output = pipe(**inputs, output="images")
        if as_slice:
            output = output[0, -3:, -3:, -1].flatten()
        return output

    def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
        r"""Tests for IP-Adapter.

        The following scenarios are tested:
          - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
          - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
          - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
          - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.

        Args:
            expected_max_diff: tolerance for the scale=0 comparisons (replaced by 9e-4 on CPU).
            expected_pipe_slice: optional reference slice used instead of running the
                pipeline without an adapter; when given, outputs are compared as 3x3 slices.
        """
        # Raising the tolerance for this test when it's run on a CPU because we
        # compare against static slices and that can be shaky (with a very low probability).
        expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff

        blocks = self.pipeline_blocks_class()
        _ = blocks.sub_blocks.pop("ip_adapter")
        pipe = blocks.init_pipeline(self.repo)
        pipe.load_default_components(torch_dtype=torch.float32)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
        as_slice = expected_pipe_slice is not None

        # forward pass without ip adapter
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        if expected_pipe_slice is None:
            output_without_adapter = pipe(**inputs, output="images")
        else:
            output_without_adapter = expected_pipe_slice

        # 1. Single IP-Adapter test cases
        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
        pipe.unet._load_ip_adapter_weights(adapter_state_dict)

        # scale=0 should have no effect; a large scale should change the output
        output_without_adapter_scale = self._run_ip_adapter_pass(pipe, 0.0, 1, cross_attention_dim, as_slice)
        output_with_adapter_scale = self._run_ip_adapter_pass(pipe, 42.0, 1, cross_attention_dim, as_slice)

        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()

        assert max_diff_without_adapter_scale < expected_max_diff, (
            "Output without ip-adapter must be same as normal inference"
        )
        assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"

        # 2. Multi IP-Adapter test cases
        adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
        adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
        pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])

        output_without_multi_adapter_scale = self._run_ip_adapter_pass(
            pipe, [0.0, 0.0], 2, cross_attention_dim, as_slice
        )
        output_with_multi_adapter_scale = self._run_ip_adapter_pass(
            pipe, [42.0, 42.0], 2, cross_attention_dim, as_slice
        )

        max_diff_without_multi_adapter_scale = np.abs(
            output_without_multi_adapter_scale - output_without_adapter
        ).max()
        max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
        assert max_diff_without_multi_adapter_scale < expected_max_diff, (
            "Output without multi-ip-adapter must be same as normal inference"
        )
        assert max_diff_with_multi_adapter_scale > 1e-2, (
            "Output with multi-ip-adapter scale must be different from normal inference"
        )
|
||||
|
||||
|
||||
class SDXLModularControlNetTests:
    """
    Mixin holding the ControlNet checks for the SDXL modular pipelines.
    """

    def test_pipeline_inputs(self):
        """The blocks must expose the ControlNet call arguments."""
        input_names = self.pipeline_blocks_class().input_names

        assert "control_image" in input_names, "`control_image` argument must be supported by the `__call__` method"
        assert "controlnet_conditioning_scale" in input_names, (
            "`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
        )

    def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
        """Add a random `control_image` tensor to `inputs` (in place) and return it."""
        controlnet_embedder_scale_factor = 2
        side = 32 * controlnet_embedder_scale_factor
        inputs["control_image"] = torch.randn((1, 3, side, side), device=torch_device)
        return inputs

    def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
        r"""Tests for ControlNet.

        The following scenarios are tested:
          - Single ControlNet with scale=0 should produce same output as no ControlNet.
          - Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
        """
        # The tolerance is raised on CPU because comparing against static slices
        # can be shaky there (with a very low probability).
        if torch_device == "cpu":
            expected_max_diff = 9e-4

        pipe = self.get_pipeline().to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        def _corner(image):
            # bottom-right 3x3 patch of the last channel, flattened
            return image[0, -3:, -3:, -1].flatten()

        # reference run: no control image supplied
        plain_inputs = self.get_dummy_inputs(torch_device)
        output_without_controlnet = _corner(pipe(**plain_inputs, output="images"))

        # control image supplied, conditioning scale 0 -> expected to match the reference
        zero_inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
        zero_inputs["controlnet_conditioning_scale"] = 0.0
        output_without_controlnet_scale = _corner(pipe(**zero_inputs, output="images"))

        # control image supplied, large conditioning scale -> expected to differ
        scaled_inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
        scaled_inputs["controlnet_conditioning_scale"] = 42.0
        output_with_controlnet_scale = _corner(pipe(**scaled_inputs, output="images"))

        max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
        max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()

        assert max_diff_without_controlnet_scale < expected_max_diff, (
            "Output without controlnet must be same as normal inference"
        )
        assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"

    def test_controlnet_cfg(self):
        """With a control image, guidance_scale=7.5 must change the output relative to guidance_scale=1.0."""
        pipe = self.get_pipeline().to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        results = []
        # 1.0 -> CFG not applied, 7.5 -> CFG applied
        for guidance_scale in (1.0, 7.5):
            pipe.update_components(guider=ClassifierFreeGuidance(guidance_scale=guidance_scale))
            call_kwargs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
            results.append(pipe(**call_kwargs, output="images"))
        out_no_cfg, out_cfg = results

        assert out_cfg.shape == out_no_cfg.shape
        max_diff = np.abs(out_cfg - out_no_cfg).max()
        assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularGuiderTests:
    """Mixin checking that swapping the guider component changes the pipeline output."""

    def test_guider_cfg(self):
        """guidance_scale=7.5 must produce a different image than guidance_scale=1.0."""
        pipe = self.get_pipeline().to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        results = []
        # 1.0 -> CFG not applied, 7.5 -> CFG applied
        for guidance_scale in (1.0, 7.5):
            pipe.update_components(guider=ClassifierFreeGuidance(guidance_scale=guidance_scale))
            call_kwargs = self.get_dummy_inputs(torch_device)
            results.append(pipe(**call_kwargs, output="images"))
        out_no_cfg, out_cfg = results

        assert out_cfg.shape == out_no_cfg.shape
        max_diff = np.abs(out_cfg - out_no_cfg).max()
        assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularPipelineFastTests(
    SDXLModularTests,
    SDXLModularIPAdapterTests,
    SDXLModularControlNetTests,
    SDXLModularGuiderTests,
    ModularPipelineTesterMixin,
    unittest.TestCase,
):
    """Fast tests for the Stable Diffusion XL modular pipeline using the base dummy inputs."""

    def test_stable_diffusion_xl_euler(self):
        """Compare the generated corner slice against the recorded reference values."""
        reference_slice = [
            0.5966781,
            0.62939394,
            0.48465094,
            0.51573336,
            0.57593524,
            0.47035995,
            0.53410417,
            0.51436996,
            0.47313565,
        ]
        self._test_stable_diffusion_xl_euler(
            expected_image_shape=(1, 64, 64, 3),
            expected_slice=reference_slice,
            expected_max_diff=1e-2,
        )

    def test_inference_batch_single_identical(self):
        """Run the shared batch-vs-single check with a tolerance of 3e-3."""
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)

    def test_stable_diffusion_xl_save_from_pretrained(self):
        """A pipeline saved and reloaded through `ModularPipeline.from_pretrained` must reproduce the output."""
        original_pipe = self.get_pipeline().to(torch_device)

        with tempfile.TemporaryDirectory() as save_dir:
            original_pipe.save_pretrained(save_dir)
            reloaded_pipe = ModularPipeline.from_pretrained(save_dir).to(torch_device)
            reloaded_pipe.load_default_components(torch_dtype=torch.float32)
            reloaded_pipe.to(torch_device)

        def _corner(pipe):
            # fresh dummy inputs per run so each pipeline gets an identically seeded generator
            image = pipe(**self.get_dummy_inputs(torch_device), output="images")
            return image[0, -3:, -3:, -1].flatten()

        before, after = [_corner(pipe) for pipe in (original_pipe, reloaded_pipe)]

        assert np.abs(before - after).max() < 1e-3
|
||||
|
||||
|
||||
class SDXLImg2ImgModularPipelineFastTests(
    SDXLModularTests,
    SDXLModularIPAdapterTests,
    SDXLModularControlNetTests,
    SDXLModularGuiderTests,
    ModularPipelineTesterMixin,
    unittest.TestCase,
):
    """Fast tests for the Stable Diffusion XL modular pipeline in image-to-image mode."""

    def get_dummy_inputs(self, device, seed=0):
        """Extend the base inputs with a seeded (1, 3, 64, 64) `image` tensor and `strength=0.8`."""
        inputs = super().get_dummy_inputs(device, seed)
        source = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
        inputs.update(image=source / 2 + 0.5, strength=0.8)
        return inputs

    def test_stable_diffusion_xl_euler(self):
        """Compare the generated corner slice against the recorded reference values."""
        reference_slice = [
            0.56943184,
            0.4702148,
            0.48048905,
            0.6235963,
            0.551138,
            0.49629188,
            0.60031277,
            0.5688907,
            0.43996853,
        ]
        self._test_stable_diffusion_xl_euler(
            expected_image_shape=(1, 64, 64, 3),
            expected_slice=reference_slice,
            expected_max_diff=1e-2,
        )

    def test_inference_batch_single_identical(self):
        """Run the shared batch-vs-single check with a tolerance of 3e-3."""
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
class SDXLInpaintingModularPipelineFastTests(
    SDXLModularTests,
    SDXLModularIPAdapterTests,
    SDXLModularControlNetTests,
    SDXLModularGuiderTests,
    ModularPipelineTesterMixin,
    unittest.TestCase,
):
    """Fast tests for the Stable Diffusion XL modular pipeline in inpainting mode."""

    def get_dummy_inputs(self, device, seed=0):
        """Extend the base inputs with a 64x64 PIL `image`, a PIL `mask_image` and `strength=1.0`."""
        inputs = super().get_dummy_inputs(device, seed)
        pixels = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        # channels-last view of the first (only) sample, on CPU
        pixels = pixels.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(pixels)).convert("RGB").resize((64, 64))

        # create mask: overwrite the region from row/column 8 onwards, then convert to grayscale
        pixels[8:, 8:, :] = 255
        mask_image = Image.fromarray(np.uint8(pixels)).convert("L").resize((64, 64))

        inputs.update(image=init_image, mask_image=mask_image, strength=1.0)
        return inputs

    def test_stable_diffusion_xl_euler(self):
        """Compare the generated corner slice against the recorded reference values."""
        reference_slice = [
            0.40872607,
            0.38842705,
            0.34893104,
            0.47837183,
            0.43792963,
            0.5332134,
            0.3716843,
            0.47274873,
            0.45000193,
        ]
        self._test_stable_diffusion_xl_euler(
            expected_image_shape=(1, 64, 64, 3),
            expected_slice=reference_slice,
            expected_max_diff=1e-2,
        )

    def test_inference_batch_single_identical(self):
        """Run the shared batch-vs-single check with a tolerance of 3e-3."""
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
@@ -0,0 +1,358 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import diffusers
|
||||
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
require_torch,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def to_np(tensor):
    """Convert a `torch.Tensor` to a NumPy array; return any other object unchanged."""
    if not isinstance(tensor, torch.Tensor):
        return tensor
    # detach from autograd and move to host memory before converting
    return tensor.detach().cpu().numpy()
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with unittest.TestCase classes.
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
including:
|
||||
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
|
||||
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
|
||||
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
|
||||
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
|
||||
- test_to_device: check if the pipeline's __call__ method can handle different devices
|
||||
"""
|
||||
|
||||
# Canonical parameters that are passed to `__call__` regardless
|
||||
# of the type of pipeline. They are always optional and have common
|
||||
# sense default values.
|
||||
optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"num_images_per_prompt",
|
||||
"latents",
|
||||
"output_type",
|
||||
]
|
||||
)
|
||||
# this is modular specific: generator needs to be a intermediate input because it's mutable
|
||||
intermediate_params = frozenset(
|
||||
[
|
||||
"generator",
|
||||
]
|
||||
)
|
||||
|
||||
def get_generator(self, seed):
    """Return a `torch.Generator` seeded with `seed`, created on CPU when `torch_device` is "mps"."""
    if torch_device == "mps":
        generator_device = "cpu"
    else:
        generator_device = torch_device
    return torch.Generator(generator_device).manual_seed(seed)
|
||||
|
||||
@property
def pipeline_class(self) -> Union[Callable, ModularPipeline]:
    """Placeholder that raises until the child test class sets `pipeline_class`."""
    message = (
        "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
        "See existing pipeline tests for reference."
    )
    raise NotImplementedError(message)
|
||||
|
||||
@property
def repo(self) -> str:
    """Placeholder that raises until the child test class sets `repo`."""
    message = "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
    raise NotImplementedError(message)
|
||||
|
||||
@property
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
    """Placeholder that raises until the child test class sets `pipeline_blocks_class`."""
    message = (
        "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
        "See existing pipeline tests for reference."
    )
    raise NotImplementedError(message)
|
||||
|
||||
def get_pipeline(self):
    """Placeholder that raises until the child test class implements `get_pipeline`."""
    message = (
        "You need to implement `get_pipeline(self)` in the child test class. "
        "See existing pipeline tests for reference."
    )
    raise NotImplementedError(message)
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
    """Placeholder that raises until the child test class implements `get_dummy_inputs`."""
    message = (
        "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
        "See existing pipeline tests for reference."
    )
    raise NotImplementedError(message)
|
||||
|
||||
@property
def params(self) -> frozenset:
    """Placeholder that raises until the child test class sets `params`."""
    message = (
        "You need to set the attribute `params` in the child test class. "
        "`params` are checked for if all values are present in `__call__`'s signature."
        " You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
        " e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
        "image pipelines, including prompts and prompt embedding overrides."
        "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
        "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
        "with non-configurable height and width arguments should set the attribute as "
        "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
        "See existing pipeline tests for reference."
    )
    raise NotImplementedError(message)
|
||||
|
||||
@property
def batch_params(self) -> frozenset:
    """Placeholder that raises until the child test class sets `batch_params`."""
    message = (
        "You need to set the attribute `batch_params` in the child test class. "
        "`batch_params` are the parameters required to be batched when passed to the pipeline's "
        "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
        "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
        "set of batch arguments has minor changes from one of the common sets of batch arguments, "
        "do not make modifications to the existing common sets of batch arguments. I.e. a text to "
        "image pipeline `negative_prompt` is not batched should set the attribute as "
        "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
        "See existing pipeline tests for reference."
    )
    raise NotImplementedError(message)
|
||||
|
||||
def setUp(self):
    """Reset compiler state and free cached accelerator memory before each test."""
    super().setUp()
    # clean up the VRAM before each test
    for release in (torch.compiler.reset, gc.collect):
        release()
    backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
    """Reset compiler state and free cached accelerator memory after each test."""
    super().tearDown()
    # clean up the VRAM after each test in case of CUDA runtime errors
    for release in (torch.compiler.reset, gc.collect):
        release()
    backend_empty_cache(torch_device)
|
||||
|
||||
def test_pipeline_call_signature(self):
|
||||
pipe = self.get_pipeline()
|
||||
input_parameters = pipe.blocks.input_names
|
||||
optional_parameters = pipe.default_call_parameters
|
||||
|
||||
def _check_for_parameters(parameters, expected_parameters, param_type):
|
||||
remaining_parameters = {param for param in parameters if param not in expected_parameters}
|
||||
assert len(remaining_parameters) == 0, (
|
||||
f"Required {param_type} parameters not present: {remaining_parameters}"
|
||||
)
|
||||
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.optional_params, optional_parameters, "optional")
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=(2,), batch_generator=True):
    """
    Check that the pipeline accepts batched inputs and returns one output per batch entry.

    Args:
        batch_sizes: iterable of batch sizes to try. The default is an immutable tuple
            rather than a list so it cannot be shared and mutated across calls.
        batch_generator: when True, pass one seeded generator per batch entry
            instead of the single generator from the dummy inputs.
    """
    pipe = self.get_pipeline()
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(torch_device)
    inputs["generator"] = self.get_generator(0)

    # silence the pipeline's logger while the batched inputs are assembled
    logger = logging.get_logger(pipe.__module__)
    logger.setLevel(level=diffusers.logging.FATAL)

    # prepare batched inputs
    batched_inputs = []
    for batch_size in batch_sizes:
        batched_input = dict(inputs)

        # replicate every batchable argument that is present in the dummy inputs
        for name in self.batch_params:
            if name in inputs:
                batched_input[name] = batch_size * [inputs[name]]

        if batch_generator and "generator" in inputs:
            batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]

        if "batch_size" in inputs:
            batched_input["batch_size"] = batch_size

        batched_inputs.append(batched_input)

    logger.setLevel(level=diffusers.logging.WARNING)
    for batch_size, batched_input in zip(batch_sizes, batched_inputs):
        output = pipe(**batched_input, output="images")
        assert len(output) == batch_size, "Output is different from expected batch size"
|
||||
|
||||
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=1e-4):
    """Check that the first batched result matches the single-input result.

    Args:
        batch_size: Number of repetitions used for every batchable input.
        expected_max_diff: Strict upper bound for the maximum absolute
            difference between the single output and the first batched output.
    """
    pipe = self.get_pipeline()
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    single_inputs = self.get_dummy_inputs(torch_device)
    # Reset the generator in case it has been used in self.get_dummy_inputs
    single_inputs["generator"] = self.get_generator(0)

    logger = logging.get_logger(pipe.__module__)
    logger.setLevel(level=diffusers.logging.FATAL)

    # Batchify: start from a copy, then repeat each batchable entry.
    batched = dict(single_inputs)
    for key in self.batch_params:
        if key in single_inputs:
            batched[key] = batch_size * [single_inputs[key]]

    if "generator" in single_inputs:
        batched["generator"] = [self.get_generator(seed) for seed in range(batch_size)]

    if "batch_size" in single_inputs:
        batched["batch_size"] = batch_size

    single_output = pipe(**single_inputs, output="images")
    batch_output = pipe(**batched, output="images")

    assert batch_output.shape[0] == batch_size

    difference = np.abs(to_np(batch_output[0]) - to_np(single_output[0]))
    max_diff = difference.max()
    assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_float16_inference(self, expected_max_diff=5e-2):
    """Compare float32 and float16 inference outputs of the same pipeline.

    Args:
        expected_max_diff: Strict upper bound for the value returned by
            ``numpy_cosine_similarity_distance`` on the flattened outputs.
    """
    pipe = self.get_pipeline()
    pipe.to(torch_device, torch.float32)
    pipe.set_progress_bar_config(disable=None)

    pipe_fp16 = self.get_pipeline()
    pipe_fp16.to(torch_device, torch.float16)
    pipe_fp16.set_progress_bar_config(disable=None)

    def _fresh_inputs():
        # Build dummy inputs and reset the generator in case it is used inside them.
        prepared = self.get_dummy_inputs(torch_device)
        if "generator" in prepared:
            prepared["generator"] = self.get_generator(0)
        return prepared

    output = pipe(**_fresh_inputs(), output="images")
    output_fp16 = pipe_fp16(**_fresh_inputs(), output="images")

    # Tensor outputs are moved to CPU before the comparison.
    if isinstance(output, torch.Tensor):
        output, output_fp16 = output.cpu(), output_fp16.cpu()

    max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
    assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
|
||||
|
||||
@require_accelerator
def test_to_device(self):
    """Check that ``pipe.to`` moves every device-aware component.

    The pipeline is moved to CPU and then to ``torch_device``; after each move
    the ``device.type`` of every component that has a ``device`` attribute is
    compared with the target.
    """
    pipe = self.get_pipeline()
    pipe.set_progress_bar_config(disable=None)

    def _component_device_types():
        # Components without a ``device`` attribute are ignored.
        found = []
        for component in pipe.components.values():
            if hasattr(component, "device"):
                found.append(component.device.type)
        return found

    pipe.to("cpu")
    cpu_devices = _component_device_types()
    assert all(device == "cpu" for device in cpu_devices), "All pipeline components are not on CPU"

    pipe.to(torch_device)
    accelerator_devices = _component_device_types()
    assert all(device == torch_device for device in accelerator_devices), (
        "All pipeline components are not on accelerator device"
    )
|
||||
|
||||
def test_inference_is_not_nan_cpu(self):
    """Run the pipeline on CPU and assert the output contains no NaN values."""
    pipe = self.get_pipeline()
    pipe.set_progress_bar_config(disable=None)
    pipe.to("cpu")

    cpu_inputs = self.get_dummy_inputs("cpu")
    output = pipe(**cpu_inputs, output="images")

    nan_count = np.isnan(to_np(output)).sum()
    assert nan_count == 0, "CPU Inference returns NaN"
|
||||
|
||||
@require_accelerator
def test_inference_is_not_nan(self):
    """Run the pipeline on ``torch_device`` and assert the output contains no NaN values."""
    pipe = self.get_pipeline()
    pipe.set_progress_bar_config(disable=None)
    pipe.to(torch_device)

    device_inputs = self.get_dummy_inputs(torch_device)
    output = pipe(**device_inputs, output="images")

    nan_count = np.isnan(to_np(output)).sum()
    assert nan_count == 0, "Accelerator Inference returns NaN"
|
||||
|
||||
def test_num_images_per_prompt(self):
    """Check the output count for combinations of batch size and images per prompt.

    Returns early, without asserting anything, when the pipeline blocks do not
    list ``num_images_per_prompt`` among their input names.
    """
    pipe = self.get_pipeline()

    supports_option = "num_images_per_prompt" in pipe.blocks.input_names
    if not supports_option:
        return

    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    batch_sizes = [1, 2]
    num_images_per_prompts = [1, 2]

    for batch_size in batch_sizes:
        for num_images_per_prompt in num_images_per_prompts:
            inputs = self.get_dummy_inputs(torch_device)

            # Repeat every batchable input ``batch_size`` times.
            repeated = {
                key: batch_size * [value] for key, value in inputs.items() if key in self.batch_params
            }
            inputs.update(repeated)

            images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")

            expected_count = batch_size * num_images_per_prompt
            assert images.shape[0] == expected_count
|
||||
|
||||
@require_accelerator
def test_components_auto_cpu_offload_inference_consistent(self):
    """Compare a regular pipeline with one managed by an auto-CPU-offload ComponentsManager.

    Both pipelines are run on fresh dummy inputs and the same corner slice of
    their first output image must agree within 1e-3.
    """
    base_pipe = self.get_pipeline().to(torch_device)

    manager = ComponentsManager()
    manager.enable_auto_cpu_offload(device=torch_device)
    offload_pipe = self.get_pipeline(components_manager=manager)

    def _corner_slice(candidate):
        # Last 3x3 patch of the final channel of the first image, flattened.
        image = candidate(**self.get_dummy_inputs(torch_device), output="images")
        return image[0, -3:, -3:, -1].flatten()

    base_slice = _corner_slice(base_pipe)
    offload_slice = _corner_slice(offload_pipe)

    assert np.abs(base_slice - offload_slice).max() < 1e-3
|
||||
|
||||
def test_save_from_pretrained(self):
    """Save a pipeline, reload it with ``ModularPipeline.from_pretrained`` and compare outputs.

    The same corner slice of the first output image of the original and the
    reloaded pipeline must agree within 1e-3.
    """
    base_pipe = self.get_pipeline().to(torch_device)

    with tempfile.TemporaryDirectory() as tmpdirname:
        base_pipe.save_pretrained(tmpdirname)
        reloaded = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
        # NOTE(review): components are reloaded as float16 while the dtype of the base
        # pipeline is not set here - confirm the 1e-3 tolerance below is intended.
        reloaded.load_default_components(torch_dtype=torch.float16)
        reloaded.to(torch_device)

    def _corner_slice(candidate):
        # Last 3x3 patch of the final channel of the first image, flattened.
        image = candidate(**self.get_dummy_inputs(torch_device), output="images")
        return image[0, -3:, -3:, -1].flatten()

    # Inference runs after the temporary directory has been removed, base pipeline first.
    image_slices = [_corner_slice(candidate) for candidate in (base_pipe, reloaded)]

    assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
@@ -20,12 +20,6 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"image",
|
||||
@@ -35,8 +29,6 @@ IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -50,8 +42,6 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# Text guided image variation with an image mask
|
||||
@@ -67,8 +57,6 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# image variation with an image mask
|
||||
@@ -80,8 +68,6 @@ IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
"example_image",
|
||||
@@ -93,20 +79,12 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -119,11 +97,38 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
# image params
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
|
||||
# batch params
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
|
||||
|
||||
# callback params
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
Reference in New Issue
Block a user