start pyramid attention broadcast
This commit is contained in:
@@ -477,7 +477,7 @@ class Attention(nn.Module):
|
||||
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
||||
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
quiet_attn_parameters = {"ip_adapter_masks"}
|
||||
quiet_attn_parameters = {"ip_adapter_masks", "image_rotary_emb"}
|
||||
unused_kwargs = [
|
||||
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
|
||||
]
|
||||
|
||||
@@ -29,6 +29,7 @@ from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
|
||||
from .pipeline_output import CogVideoXPipelineOutput
|
||||
|
||||
|
||||
@@ -137,7 +138,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using CogVideoX.
|
||||
|
||||
@@ -605,6 +606,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -674,6 +676,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -729,6 +732,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
self._current_timestep = None
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..models.attention_processor import Attention, AttentionProcessor
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastAttentionProcessor:
    r"""
    Wrapper around an attention processor that can reuse its previous output instead of recomputing it.

    On every call the wrapper either delegates to the wrapped processor and caches the result, or returns the
    cached result unchanged. The cached result is reused only when all of the following hold:

    - a result has already been cached by this wrapper,
    - the pipeline exposes a non-`None` `_current_timestep`, `_pab_skip_range` and `_pab_timestep_range`,
    - the internal call counter is not a multiple of `pipeline._pab_skip_range`,
    - `pipeline._current_timestep` lies strictly inside `pipeline._pab_timestep_range`.

    Args:
        pipeline:
            Object that owns the Pyramid Attention Broadcast state. It is read for `_current_timestep`,
            `_pab_skip_range`, `_pab_timestep_range` and `num_timesteps`.
        processor (`AttentionProcessor`):
            The processor to wrap. It is kept in `_original_processor` so that it can be restored later.
    """

    def __init__(self, pipeline, processor: "AttentionProcessor") -> None:
        self.pipeline = pipeline
        self._original_processor = processor
        # Output of the most recent real computation; `None` until the wrapped processor has run once.
        self._prev_hidden_states = None
        # Number of calls seen so far, wrapped modulo `pipeline.num_timesteps`.
        # NOTE(review): this assumes the wrapper is invoked once per denoising step - confirm against the denoiser.
        self._iteration = 0

    def _should_reuse_cache(self) -> bool:
        r"""Return `True` when the cached output may be returned instead of calling the wrapped processor."""
        # Nothing to broadcast yet. Without this guard a stale call counter would make `__call__` return `None`.
        if self._prev_hidden_states is None:
            return False

        pipeline = self.pipeline
        timestep = getattr(pipeline, "_current_timestep", None)
        skip_range = getattr(pipeline, "_pab_skip_range", None)
        timestep_range = getattr(pipeline, "_pab_timestep_range", None)

        # Outside of the denoising loop, or with the settings cleared, behave exactly like the wrapped processor
        # rather than failing on `% None` / `None[0]`.
        if timestep is None or skip_range is None or timestep_range is None:
            return False

        # Every `skip_range`-th call refreshes the cache.
        if self._iteration % skip_range == 0:
            return False

        lower, upper = timestep_range[0], timestep_range[1]
        return bool(lower < timestep < upper)

    def __call__(self, *args, **kwargs):
        r"""
        Run the wrapped processor, or return its cached output when reuse is allowed.

        All positional and keyword arguments are forwarded untouched to the wrapped processor.

        Returns:
            Whatever the wrapped processor returned, either from this call or from the most recent computed call.
        """
        if self._should_reuse_cache():
            hidden_states = self._prev_hidden_states
        else:
            hidden_states = self._original_processor(*args, **kwargs)
            self._prev_hidden_states = hidden_states

        self._iteration = (self._iteration + 1) % self.pipeline.num_timesteps

        return hidden_states
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastMixin:
    r"""
    Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588).

    The host class is expected to expose its denoiser as `self.transformer` or, failing that, `self.unet`. Enabling
    replaces the processor of every `Attention` module in the denoiser with a
    `PyramidAttentionBroadcastAttentionProcessor`; disabling puts the original processors back.
    """

    def _enable_pyramid_attention_broadcast(self) -> None:
        r"""Wrap the processor of every `Attention` module in the denoiser."""
        denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet

        for _, module in denoiser.named_modules():
            if not isinstance(module, Attention):
                continue
            # Calling enable twice must not stack wrappers: a nested wrapper would hide the real processor from
            # `_disable_pyramid_attention_broadcast`, which only unwraps one level.
            if isinstance(module.processor, PyramidAttentionBroadcastAttentionProcessor):
                continue
            module.processor = PyramidAttentionBroadcastAttentionProcessor(self, module.processor)

    def _disable_pyramid_attention_broadcast(self) -> None:
        r"""Restore the original processor on every `Attention` module that is currently wrapped."""
        denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet

        for name, module in denoiser.named_modules():
            if isinstance(module, Attention) and isinstance(
                module.processor, PyramidAttentionBroadcastAttentionProcessor
            ):
                # TODO: make this debug
                logger.info(f"Disabling Pyramid Attention Broadcast in layer: {name}")
                module.processor = module.processor._original_processor

    def enable_pyramid_attention_broadcast(self, skip_range: int, timestep_range: Tuple[int, int]) -> None:
        r"""
        Enable Pyramid Attention Broadcast on the denoiser.

        Args:
            skip_range (`int`):
                The wrapped processors recompute attention on every `skip_range`-th call and may reuse the cached
                output on the calls in between. Must be a positive integer; `1` never reuses the cache.
            timestep_range (`Tuple[int, int]`):
                Exclusive `(lower, upper)` bounds. The cached output is only reused while the pipeline's current
                timestep lies strictly between them.

        Raises:
            ValueError: If `skip_range` is not a positive integer or `timestep_range` is not a pair.
        """
        # The processors evaluate `iteration % skip_range`, so anything but a positive integer would only fail
        # later, in the middle of denoising.
        if not isinstance(skip_range, int) or skip_range < 1:
            raise ValueError(f"`skip_range` must be a positive integer, but got {skip_range!r}.")

        try:
            _lower, _upper = timestep_range
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"`timestep_range` must be a pair of (lower, upper) timesteps, but got {timestep_range!r}."
            ) from err

        self._pab_skip_range = skip_range
        self._pab_timestep_range = timestep_range

        self._enable_pyramid_attention_broadcast()

    def disable_pyramid_attention_broadcast(self) -> None:
        r"""Disable Pyramid Attention Broadcast, restore the original attention processors and clear the settings."""
        # Unwrap before clearing the settings so that no wrapper is left behind on the denoiser.
        if self.pyramid_attention_broadcast_enabled:
            self._disable_pyramid_attention_broadcast()

        self._pab_timestep_range = None
        self._pab_skip_range = None

    @property
    def pyramid_attention_broadcast_enabled(self):
        r"""Whether a skip range is currently set, i.e. Pyramid Attention Broadcast has been enabled."""
        return hasattr(self, "_pab_skip_range") and self._pab_skip_range is not None
|
||||
Reference in New Issue
Block a user