Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e6b52a2c2b | |||
| f8ba5cd77a | |||
| c9c8217306 | |||
| 135df5be9d | |||
| 4a9dbd56f6 | |||
| 630d27fe5b |
@@ -25,6 +25,8 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b
|
||||
|
||||
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
|
||||
|
||||
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
</Tip>
|
||||
|
||||
Flux comes in the following variants:
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ export_to_video(video, "output.mp4", fps=24)
|
||||
</hfoption>
|
||||
<hfoption id="inference speed">
|
||||
|
||||
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
|
||||
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
@@ -20,7 +20,7 @@ Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ export_to_video(output, "output.mp4", fps=16)
|
||||
</hfoption>
|
||||
<hfoption id="T2V inference speed">
|
||||
|
||||
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
|
||||
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
```py
|
||||
# pip install ftfy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
|
||||
+1
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import io
|
||||
|
||||
+1
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
+1
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
+1
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import functools
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -162,7 +163,7 @@ class QwenEmbedRope(nn.Module):
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(1024)
|
||||
neg_index = torch.arange(1024).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat(
|
||||
pos_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
@@ -170,7 +171,7 @@ class QwenEmbedRope(nn.Module):
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
neg_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
@@ -179,6 +180,8 @@ class QwenEmbedRope(nn.Module):
|
||||
dim=1,
|
||||
)
|
||||
self.rope_cache = {}
|
||||
self.register_buffer("pos_freqs", pos_freqs, persistent=False)
|
||||
self.register_buffer("neg_freqs", neg_freqs, persistent=False)
|
||||
|
||||
# 是否使用 scale rope
|
||||
self.scale_rope = scale_rope
|
||||
@@ -198,33 +201,17 @@ class QwenEmbedRope(nn.Module):
|
||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
frame, height, width = video_fhw
|
||||
rope_key = f"{frame}_{height}_{width}"
|
||||
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||
vid_freqs = self.rope_cache[rope_key]
|
||||
if not torch.compiler.is_compiling():
|
||||
if rope_key not in self.rope_cache:
|
||||
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
|
||||
vid_freqs = self.rope_cache[rope_key]
|
||||
else:
|
||||
vid_freqs = self._compute_video_freqs(frame, height, width)
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2)
|
||||
@@ -236,6 +223,25 @@ class QwenEmbedRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
|
||||
class QwenDoubleStreamAttnProcessor2_0:
|
||||
"""
|
||||
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["QwenImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -7,9 +7,15 @@ from ..utils import (
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
logger.warning(
|
||||
"Modular Diffusers is currently an experimental feature under active development. The API is subject to breaking changes in future releases."
|
||||
)
|
||||
|
||||
# These modules contain pipelines from multiple libraries/frameworks
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
@@ -25,7 +31,6 @@ else:
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"ModularPipelineBlocks",
|
||||
"ModularPipeline",
|
||||
"PipelineBlock",
|
||||
"AutoPipelineBlocks",
|
||||
"SequentialPipelineBlocks",
|
||||
"LoopSequentialPipelineBlocks",
|
||||
@@ -59,21 +64,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipeline,
|
||||
ModularPipelineBlocks,
|
||||
PipelineBlock,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
from .modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
InputParam,
|
||||
InsertableDict,
|
||||
OutputParam,
|
||||
)
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
)
|
||||
from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
|
||||
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
from .wan import WanAutoBlocks, WanModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...models import AutoencoderKL
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
@@ -231,7 +231,7 @@ def _get_initial_timesteps_and_optionals(
|
||||
return timesteps, num_inference_steps, sigmas, guidance
|
||||
|
||||
|
||||
class FluxInputStep(PipelineBlock):
|
||||
class FluxInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -249,11 +249,6 @@ class FluxInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -322,7 +317,7 @@ class FluxInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxSetTimestepsStep(PipelineBlock):
|
||||
class FluxSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -340,14 +335,10 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
InputParam("timesteps"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("guidance_scale", default=3.5),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
@@ -398,7 +389,7 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -420,11 +411,6 @@ class FluxImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
@@ -497,7 +483,7 @@ class FluxImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxPrepareLatentsStep(PipelineBlock):
|
||||
class FluxPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -515,11 +501,6 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -621,7 +602,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -639,11 +620,6 @@ class FluxImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import logging
|
||||
from ...video_processor import VaeImageProcessor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
return latents
|
||||
|
||||
|
||||
class FluxDecodeStep(PipelineBlock):
|
||||
class FluxDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -70,17 +70,12 @@ class FluxDecodeStep(PipelineBlock):
|
||||
InputParam("output_type", default="pil"),
|
||||
InputParam("height", default=1024),
|
||||
InputParam("width", default=1024),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -32,7 +32,7 @@ from .modular_pipeline import FluxModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxLoopDenoiser(PipelineBlock):
|
||||
class FluxLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -49,11 +49,8 @@ class FluxLoopDenoiser(PipelineBlock):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [InputParam("joint_attention_kwargs")]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("joint_attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -113,7 +110,7 @@ class FluxLoopDenoiser(PipelineBlock):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class FluxLoopAfterDenoiser(PipelineBlock):
|
||||
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -175,7 +172,7 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
@@ -67,7 +67,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class FluxVaeEncoderStep(PipelineBlock):
|
||||
class FluxVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -88,11 +88,10 @@ class FluxVaeEncoderStep(PipelineBlock):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("image", required=True), InputParam("height"), InputParam("width")]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
@@ -157,7 +156,7 @@ class FluxVaeEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxTextEncoderStep(PipelineBlock):
|
||||
class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
|
||||
|
||||
def make_doc_string(
|
||||
inputs,
|
||||
intermediate_inputs,
|
||||
outputs,
|
||||
description="",
|
||||
class_name=None,
|
||||
@@ -664,7 +663,7 @@ def make_doc_string(
|
||||
output += configs_str + "\n\n"
|
||||
|
||||
# Add inputs section
|
||||
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
|
||||
output += format_input_params(inputs, indent_level=2)
|
||||
|
||||
# Add outputs section
|
||||
output += "\n\n"
|
||||
|
||||
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
|
||||
return latents
|
||||
|
||||
|
||||
class StableDiffusionXLInputStep(PipelineBlock):
|
||||
class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
InputParam("denoising_start"),
|
||||
# YiYi TODO: do we need num_images_per_prompt here?
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
"`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
|
||||
"`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -890,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -910,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("denoising_start"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
@@ -971,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -992,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width"),
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -1082,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1119,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("aesthetic_score", default=6.0),
|
||||
InputParam("negative_aesthetic_score", default=2.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1306,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1335,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("negative_crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1489,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1517,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1708,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1737,11 +1697,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -184,11 +179,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
InputParam("image"),
|
||||
InputParam("mask_image"),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# YiYi experimenting composible denoise loop
|
||||
# loop step (1): prepare latent input for denoiser
|
||||
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (1): prepare latent input for denoiser (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance
|
||||
class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance (with controlnet)
|
||||
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"controlnet_cond",
|
||||
required=True,
|
||||
@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents
|
||||
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||
|
||||
@@ -57,7 +57,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
@@ -668,12 +663,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
image = components.image_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
|
||||
)
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
image = image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.batch_size = image.shape[0]
|
||||
|
||||
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||
@@ -682,16 +676,14 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -726,11 +718,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("mask_image", required=True),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
@@ -860,34 +847,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
block_state.crops_coords = None
|
||||
block_state.resize_mode = "default"
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
image = components.image_processor.preprocess(
|
||||
block_state.image,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
crops_coords=block_state.crops_coords,
|
||||
resize_mode=block_state.resize_mode,
|
||||
)
|
||||
block_state.image = block_state.image.to(dtype=torch.float32)
|
||||
image = image.to(dtype=torch.float32)
|
||||
|
||||
block_state.mask = components.mask_processor.preprocess(
|
||||
mask = components.mask_processor.preprocess(
|
||||
block_state.mask_image,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
resize_mode=block_state.resize_mode,
|
||||
crops_coords=block_state.crops_coords,
|
||||
)
|
||||
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
|
||||
block_state.masked_image = image * (mask < 0.5)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
block_state.batch_size = image.shape[0]
|
||||
image = image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
|
||||
components,
|
||||
block_state.mask,
|
||||
mask,
|
||||
block_state.masked_image,
|
||||
block_state.batch_size,
|
||||
block_state.height,
|
||||
|
||||
@@ -247,10 +247,6 @@ SDXL_INPUTS_SCHEMA = {
|
||||
"control_mode": InputParam(
|
||||
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
@@ -271,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"preprocess_kwargs": InputParam(
|
||||
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
|
||||
),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
|
||||
),
|
||||
"latent_timestep": InputParam(
|
||||
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
|
||||
),
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -94,7 +94,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class WanInputStep(PipelineBlock):
|
||||
class WanInputStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanSetTimestepsStep(PipelineBlock):
|
||||
class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareLatentsStep(PipelineBlock):
|
||||
class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKLWan
|
||||
from ...utils import logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanDecodeStep(PipelineBlock):
|
||||
class WanDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanLoopDenoiser(PipelineBlock):
|
||||
class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanLoopAfterDenoiser(PipelineBlock):
|
||||
class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...utils import is_ftfy_available, logging
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -51,7 +51,7 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
class WanTextEncoderStep(PipelineBlock):
|
||||
class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
+1
@@ -10,6 +10,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
+1
@@ -10,6 +10,7 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
@@ -1711,6 +1711,11 @@ class ModelTesterMixin:
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
|
||||
if self.model_class.__name__ == "QwenImageTransformer2DModel":
|
||||
pytest.skip(
|
||||
"QwenImageTransformer2DModel doesn't support group offloading with disk. Needs to be investigated."
|
||||
)
|
||||
|
||||
def _has_generator_arg(model):
|
||||
sig = inspect.signature(model.forward)
|
||||
params = sig.parameters
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 7
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
"txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
-4
@@ -117,13 +117,9 @@ class SDXLModularIPAdapterTests:
|
||||
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
intermediate_parameters = blocks.intermediate_input_names
|
||||
assert "ip_adapter_image" not in parameters, (
|
||||
"`ip_adapter_image` argument must be removed from the `__call__` method"
|
||||
)
|
||||
assert "ip_adapter_image_embeds" not in intermediate_parameters, (
|
||||
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method"
|
||||
)
|
||||
|
||||
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
@@ -139,7 +139,6 @@ class ModularPipelineTesterMixin:
|
||||
def test_pipeline_call_signature(self):
|
||||
pipe = self.get_pipeline()
|
||||
input_parameters = pipe.blocks.input_names
|
||||
intermediate_parameters = pipe.blocks.intermediate_input_names
|
||||
optional_parameters = pipe.default_call_parameters
|
||||
|
||||
def _check_for_parameters(parameters, expected_parameters, param_type):
|
||||
@@ -149,7 +148,6 @@ class ModularPipelineTesterMixin:
|
||||
)
|
||||
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
|
||||
_check_for_parameters(self.optional_params, optional_parameters, "optional")
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
|
||||
|
||||
@@ -124,37 +124,22 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_3_different_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
def test_inference(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
image = pipe(**inputs).images[0]
|
||||
generated_slice = image.flatten()
|
||||
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = "a different prompt"
|
||||
inputs["prompt_3"] = "another different prompt"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
# fmt: off
|
||||
expected_slice = np.array([0.5112, 0.5228, 0.5235, 0.5524, 0.3188, 0.5017, 0.5574, 0.4899, 0.6812, 0.5991, 0.3908, 0.5213, 0.5582, 0.4457, 0.4204, 0.5616])
|
||||
# fmt: on
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
|
||||
def test_stable_diffusion_3_different_negative_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["negative_prompt_2"] = "deformed"
|
||||
inputs["negative_prompt_3"] = "blurry"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
self.assertTrue(
|
||||
np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
|
||||
)
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
@@ -268,40 +253,9 @@ class StableDiffusion3PipelineSlowTests(unittest.TestCase):
|
||||
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, :10, :10]
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.4648,
|
||||
0.4404,
|
||||
0.4177,
|
||||
0.5063,
|
||||
0.4800,
|
||||
0.4287,
|
||||
0.5425,
|
||||
0.5190,
|
||||
0.4717,
|
||||
0.5430,
|
||||
0.5195,
|
||||
0.4766,
|
||||
0.5361,
|
||||
0.5122,
|
||||
0.4612,
|
||||
0.4871,
|
||||
0.4749,
|
||||
0.4058,
|
||||
0.4756,
|
||||
0.4678,
|
||||
0.3804,
|
||||
0.4832,
|
||||
0.4822,
|
||||
0.3799,
|
||||
0.5103,
|
||||
0.5034,
|
||||
0.3953,
|
||||
0.5073,
|
||||
0.4839,
|
||||
0.3884,
|
||||
]
|
||||
)
|
||||
# fmt: off
|
||||
expected_slice = np.array([0.4648, 0.4404, 0.4177, 0.5063, 0.4800, 0.4287, 0.5425, 0.5190, 0.4717, 0.5430, 0.5195, 0.4766, 0.5361, 0.5122, 0.4612, 0.4871, 0.4749, 0.4058, 0.4756, 0.4678, 0.3804, 0.4832, 0.4822, 0.3799, 0.5103, 0.5034, 0.3953, 0.5073, 0.4839, 0.3884])
|
||||
# fmt: on
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
|
||||
|
||||
|
||||
@@ -128,37 +128,22 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_3_img2img_different_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
def test_inference(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
image = pipe(**inputs).images[0]
|
||||
generated_slice = image.flatten()
|
||||
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = "a different prompt"
|
||||
inputs["prompt_3"] = "another different prompt"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
# fmt: off
|
||||
expected_slice = np.array([0.4564, 0.5486, 0.4868, 0.5923, 0.3775, 0.5543, 0.4807, 0.4177, 0.3778, 0.5957, 0.5726, 0.4333, 0.6312, 0.5062, 0.4838, 0.5984])
|
||||
# fmt: on
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
|
||||
def test_stable_diffusion_3_img2img_different_negative_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["negative_prompt_2"] = "deformed"
|
||||
inputs["negative_prompt_3"] = "blurry"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
self.assertTrue(
|
||||
np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
|
||||
)
|
||||
|
||||
@unittest.skip("Skip for now.")
|
||||
def test_multi_vae(self):
|
||||
@@ -207,112 +192,16 @@ class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images[0]
|
||||
image_slice = image[0, :10, :10]
|
||||
|
||||
# fmt: off
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): np.array(
|
||||
[
|
||||
0.5117,
|
||||
0.4421,
|
||||
0.3852,
|
||||
0.5044,
|
||||
0.4219,
|
||||
0.3262,
|
||||
0.5024,
|
||||
0.4329,
|
||||
0.3276,
|
||||
0.4978,
|
||||
0.4412,
|
||||
0.3355,
|
||||
0.4983,
|
||||
0.4338,
|
||||
0.3279,
|
||||
0.4893,
|
||||
0.4241,
|
||||
0.3129,
|
||||
0.4875,
|
||||
0.4253,
|
||||
0.3030,
|
||||
0.4961,
|
||||
0.4267,
|
||||
0.2988,
|
||||
0.5029,
|
||||
0.4255,
|
||||
0.3054,
|
||||
0.5132,
|
||||
0.4248,
|
||||
0.3222,
|
||||
]
|
||||
),
|
||||
("cuda", 7): np.array(
|
||||
[
|
||||
0.5435,
|
||||
0.4673,
|
||||
0.5732,
|
||||
0.4438,
|
||||
0.3557,
|
||||
0.4912,
|
||||
0.4331,
|
||||
0.3491,
|
||||
0.4915,
|
||||
0.4287,
|
||||
0.347,
|
||||
0.4849,
|
||||
0.4355,
|
||||
0.3469,
|
||||
0.4871,
|
||||
0.4431,
|
||||
0.3538,
|
||||
0.4912,
|
||||
0.4521,
|
||||
0.3643,
|
||||
0.5059,
|
||||
0.4587,
|
||||
0.373,
|
||||
0.5166,
|
||||
0.4685,
|
||||
0.3845,
|
||||
0.5264,
|
||||
0.4746,
|
||||
0.3914,
|
||||
0.5342,
|
||||
]
|
||||
),
|
||||
("cuda", 8): np.array(
|
||||
[
|
||||
0.5146,
|
||||
0.4385,
|
||||
0.3826,
|
||||
0.5098,
|
||||
0.4150,
|
||||
0.3218,
|
||||
0.5142,
|
||||
0.4312,
|
||||
0.3298,
|
||||
0.5127,
|
||||
0.4431,
|
||||
0.3411,
|
||||
0.5171,
|
||||
0.4424,
|
||||
0.3374,
|
||||
0.5088,
|
||||
0.4348,
|
||||
0.3242,
|
||||
0.5073,
|
||||
0.4380,
|
||||
0.3174,
|
||||
0.5132,
|
||||
0.4397,
|
||||
0.3115,
|
||||
0.5132,
|
||||
0.4343,
|
||||
0.3118,
|
||||
0.5219,
|
||||
0.4328,
|
||||
0.3256,
|
||||
]
|
||||
),
|
||||
("xpu", 3): np.array([0.5117, 0.4421, 0.3852, 0.5044, 0.4219, 0.3262, 0.5024, 0.4329, 0.3276, 0.4978, 0.4412, 0.3355, 0.4983, 0.4338, 0.3279, 0.4893, 0.4241, 0.3129, 0.4875, 0.4253, 0.3030, 0.4961, 0.4267, 0.2988, 0.5029, 0.4255, 0.3054, 0.5132, 0.4248, 0.3222]),
|
||||
("cuda", 7): np.array([0.5435, 0.4673, 0.5732, 0.4438, 0.3557, 0.4912, 0.4331, 0.3491, 0.4915, 0.4287, 0.347, 0.4849, 0.4355, 0.3469, 0.4871, 0.4431, 0.3538, 0.4912, 0.4521, 0.3643, 0.5059, 0.4587, 0.373, 0.5166, 0.4685, 0.3845, 0.5264, 0.4746, 0.3914, 0.5342]),
|
||||
("cuda", 8): np.array([0.5146, 0.4385, 0.3826, 0.5098, 0.4150, 0.3218, 0.5142, 0.4312, 0.3298, 0.5127, 0.4431, 0.3411, 0.5171, 0.4424, 0.3374, 0.5088, 0.4348, 0.3242, 0.5073, 0.4380, 0.3174, 0.5132, 0.4397, 0.3115, 0.5132, 0.4343, 0.3118, 0.5219, 0.4328, 0.3256]),
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
|
||||
|
||||
@@ -132,37 +132,23 @@ class StableDiffusion3InpaintPipelineFastTests(PipelineLatentTesterMixin, unitte
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_3_inpaint_different_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
def test_inference(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
image = pipe(**inputs).images[0]
|
||||
generated_slice = image.flatten()
|
||||
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = "a different prompt"
|
||||
inputs["prompt_3"] = "another different prompt"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
# fmt: off
|
||||
expected_slice = np.array([0.5035, 0.6661, 0.5859, 0.413, 0.4224, 0.4234, 0.7181, 0.5062, 0.5183, 0.6877, 0.5074, 0.585, 0.6111, 0.5422, 0.5306, 0.5891])
|
||||
# fmt: on
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
|
||||
def test_stable_diffusion_3_inpaint_different_negative_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["negative_prompt_2"] = "deformed"
|
||||
inputs["negative_prompt_3"] = "blurry"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
self.assertTrue(
|
||||
np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
|
||||
)
|
||||
|
||||
@unittest.skip("Skip for now.")
|
||||
def test_multi_vae(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user