start pyramid attention broadcast

This commit is contained in:
Aryan
2024-10-01 03:31:06 +02:00
parent c4a8979f30
commit 67c729d448
3 changed files with 117 additions and 2 deletions
+1 -1
View File
@@ -477,7 +477,7 @@ class Attention(nn.Module):
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks"}
quiet_attn_parameters = {"ip_adapter_masks", "image_rotary_emb"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
@@ -29,6 +29,7 @@ from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
from .pipeline_output import CogVideoXPipelineOutput
@@ -137,7 +138,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.
@@ -605,6 +606,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -674,6 +676,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -729,6 +732,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self._current_timestep = None
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
@@ -0,0 +1,111 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch.nn as nn
from ..models.attention_processor import Attention, AttentionProcessor
from ..utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class PyramidAttentionBroadcastAttentionProcessor:
def __init__(self, pipeline, processor: AttentionProcessor) -> None:
self.pipeline = pipeline
self._original_processor = processor
self._prev_hidden_states = None
self._iteration = 0
def __call__(self, *args, **kwargs):
if (
hasattr(self.pipeline, "_current_timestep")
and self.pipeline._current_timestep is not None
and self._iteration % self.pipeline._pab_skip_range != 0
and (
self.pipeline._pab_timestep_range[0]
< self.pipeline._current_timestep
< self.pipeline._pab_timestep_range[1]
)
):
# print("Using cached states:", self.pipeline._current_timestep)
hidden_states = self._prev_hidden_states
else:
hidden_states = self._original_processor(*args, **kwargs)
self._prev_hidden_states = hidden_states
self._iteration = (self._iteration + 1) % self.pipeline.num_timesteps
return hidden_states
class PyramidAttentionBroadcastMixin:
r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588)."""
def _enable_pyramid_attention_broadcast(self) -> None:
# def is_fake_integral_match(layer_id, name):
# layer_id = layer_id.split(".")[-1]
# name = name.split(".")[-1]
# return layer_id.isnumeric() and name.isnumeric() and layer_id == name
denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet
for name, module in denoiser.named_modules():
if isinstance(module, Attention):
module.processor = PyramidAttentionBroadcastAttentionProcessor(self, module.processor)
# target_modules = {}
# for layer_id in self._pab_skip_range:
# for name, module in denoiser.named_modules():
# if (
# isinstance(module, Attention)
# and re.search(layer_id, name) is not None
# and not is_fake_integral_match(layer_id, name)
# ):
# target_modules[name] = module
# for name, module in target_modules.items():
# # TODO: make this debug
# logger.info(f"Enabling Pyramid Attention Broadcast in layer: {name}")
# module.processor = PyramidAttentionBroadcastAttentionProcessor(self, module.processor)
def _disable_pyramid_attention_broadcast(self) -> None:
denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet
for name, module in denoiser.named_modules():
if isinstance(module, Attention) and isinstance(
module.processor, PyramidAttentionBroadcastAttentionProcessor
):
# TODO: make this debug
logger.info(f"Disabling Pyramid Attention Broadcast in layer: {name}")
module.processor = module.processor._original_processor
def enable_pyramid_attention_broadcast(self, skip_range: int, timestep_range: Tuple[int, int]) -> None:
if isinstance(skip_range, str):
skip_range = [skip_range]
self._pab_skip_range = skip_range
self._pab_timestep_range = timestep_range
self._enable_pyramid_attention_broadcast()
def disable_pyramid_attention_broadcast(self) -> None:
self._pab_timestep_range = None
self._pab_skip_range = None
@property
def pyramid_attention_broadcast_enabled(self):
return hasattr(self, "_pab_skip_range") and self._pab_skip_range is not None