diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9f9bc5a46e..a207770f2f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -477,7 +477,7 @@ class Attention(nn.Module): # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks"} + quiet_attn_parameters = {"ip_adapter_masks", "image_rotary_emb"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 82839ffd2c..fe35c85dcc 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -29,6 +29,7 @@ from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -137,7 +138,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): r""" Pipeline for text-to-video generation using CogVideoX. @@ -605,6 +606,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. 
Default call parameters @@ -674,6 +676,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -729,6 +732,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py new file mode 100644 index 0000000000..23d3a6b25c --- /dev/null +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -0,0 +1,111 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Tuple

import torch.nn as nn

from ..models.attention_processor import Attention, AttentionProcessor
from ..utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class PyramidAttentionBroadcastAttentionProcessor:
    r"""
    Wrap an attention processor and reuse its previous output on selected calls.

    The wrapper forwards every call to the wrapped processor unless all of the following hold, in which case the
    output of the most recent forwarded call is returned instead:

    - the owning pipeline exposes a non-`None` `_current_timestep`,
    - the pipeline's `_pab_skip_range` and `_pab_timestep_range` are set,
    - an output has already been cached,
    - the internal call counter is not a multiple of `_pab_skip_range`,
    - `_current_timestep` lies strictly between the two bounds of `_pab_timestep_range`.

    Args:
        pipeline: Object whose `_current_timestep`, `_pab_skip_range`, `_pab_timestep_range` and `num_timesteps`
            attributes are read on every call.
        processor (`AttentionProcessor`): The processor being wrapped; kept on `_original_processor` so that it can
            be restored later.
    """

    def __init__(self, pipeline, processor: AttentionProcessor) -> None:
        self.pipeline = pipeline
        self._original_processor = processor
        # Output of the last call that was forwarded to the wrapped processor.
        self._prev_hidden_states = None
        # Call counter, wrapped around at `pipeline.num_timesteps`.
        self._iteration = 0

    def _can_reuse_cached_states(self) -> bool:
        """Return `True` when the cached output may be returned instead of recomputing attention."""
        pipeline = self.pipeline
        timestep = getattr(pipeline, "_current_timestep", None)
        skip_range = getattr(pipeline, "_pab_skip_range", None)
        timestep_range = getattr(pipeline, "_pab_timestep_range", None)

        # Settings are reset to `None` when the feature is disabled; never touch them in that state.
        if timestep is None or skip_range is None or timestep_range is None:
            return False
        # Nothing has been computed yet, so there is nothing to broadcast.
        if self._prev_hidden_states is None:
            return False
        # Every `skip_range`-th call is always recomputed to refresh the cache.
        if self._iteration % skip_range == 0:
            return False
        return bool(timestep_range[0] < timestep < timestep_range[1])

    # NOTE(review): `Attention` inspects the signature of `processor.__call__` to detect unexpected keyword
    # arguments. With `*args, **kwargs` here no named parameter is visible to that check - confirm that keyword
    # arguments such as `image_rotary_emb` are still forwarded to the wrapped processor.
    def __call__(self, *args, **kwargs):
        """Return cached hidden states when allowed, otherwise call the wrapped processor and cache its output."""
        if self._can_reuse_cached_states():
            hidden_states = self._prev_hidden_states
        else:
            hidden_states = self._original_processor(*args, **kwargs)
            self._prev_hidden_states = hidden_states

        # Assumes the pipeline exposes a non-zero `num_timesteps` while denoising - TODO confirm.
        self._iteration = (self._iteration + 1) % self.pipeline.num_timesteps

        return hidden_states


class PyramidAttentionBroadcastMixin:
    r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588)."""

    def _get_pab_denoiser(self) -> nn.Module:
        """Return `self.transformer` when the pipeline has one, otherwise `self.unet`."""
        return self.transformer if hasattr(self, "transformer") else self.unet

    def _enable_pyramid_attention_broadcast(self) -> None:
        """Wrap the processor of every `Attention` module of the denoiser with the broadcasting processor."""
        for name, module in self._get_pab_denoiser().named_modules():
            if not isinstance(module, Attention):
                continue
            # Wrapping twice would nest caches and call counters, so already-wrapped layers are left alone. They
            # read their settings from the pipeline on every call and therefore pick up new values automatically.
            if isinstance(module.processor, PyramidAttentionBroadcastAttentionProcessor):
                continue
            logger.debug("Enabling Pyramid Attention Broadcast in layer: %s", name)
            module.processor = PyramidAttentionBroadcastAttentionProcessor(self, module.processor)

    def _disable_pyramid_attention_broadcast(self) -> None:
        """Restore the original processor on every `Attention` module that is currently wrapped."""
        for name, module in self._get_pab_denoiser().named_modules():
            if isinstance(module, Attention) and isinstance(
                module.processor, PyramidAttentionBroadcastAttentionProcessor
            ):
                logger.debug("Disabling Pyramid Attention Broadcast in layer: %s", name)
                module.processor = module.processor._original_processor

    def enable_pyramid_attention_broadcast(self, skip_range: int, timestep_range: Tuple[int, int]) -> None:
        r"""
        Enable Pyramid Attention Broadcast on the denoiser of this pipeline.

        Args:
            skip_range (`int`):
                Attention is recomputed on every `skip_range`-th processor call; the cached output may be reused on
                the calls in between. Must be a positive integer.
            timestep_range (`Tuple[int, int]`):
                Lower and upper bound (both exclusive) between which the current timestep must lie for the cached
                output to be reused.

        Raises:
            ValueError: If `skip_range` is not a positive integer or `timestep_range` does not hold two values.
        """
        # `skip_range` is used as a modulus by the processor: anything but a positive integer would only fail
        # later, in the middle of the denoising loop.
        if not isinstance(skip_range, int) or skip_range < 1:
            raise ValueError(f"`skip_range` must be a positive integer, but got {skip_range!r}.")

        try:
            lower, upper = timestep_range
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"`timestep_range` must contain exactly two values (lower, upper), but got {timestep_range!r}."
            ) from err

        self._pab_skip_range = skip_range
        self._pab_timestep_range = (lower, upper)

        self._enable_pyramid_attention_broadcast()

    def disable_pyramid_attention_broadcast(self) -> None:
        """Disable Pyramid Attention Broadcast and restore the original attention processors."""
        # Unwrap first: a wrapper left in place would keep running after its settings are cleared.
        self._disable_pyramid_attention_broadcast()
        self._pab_timestep_range = None
        self._pab_skip_range = None

    @property
    def pyramid_attention_broadcast_enabled(self):
        """Whether `enable_pyramid_attention_broadcast` has been called without a subsequent disable."""
        return getattr(self, "_pab_skip_range", None) is not None