[Core] add QKV fusion to AuraFlow and PixArt Sigma (#8952)
* add fusion support to pixart
* add to auraflow
* add tests
* apply review feedback
* add back args and kwargs
* style
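For context: QKV fusion replaces the three separate query/key/value `nn.Linear` projections of an attention block with a single wider linear layer whose weight is the row-wise concatenation of the three, so one matmul produces all three projections and `torch.split` recovers them. A minimal standalone sketch of the idea in plain PyTorch (not the diffusers implementation itself):

```python
import torch
import torch.nn as nn

dim = 64
x = torch.randn(2, 16, dim)  # (batch, tokens, dim)

# Three separate projections, as in an unfused attention block.
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# Fuse: concatenate the three weight matrices along the output dimension.
fused = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))

# One matmul, then split back into q, k, v.
q, k, v = torch.split(fused(x), dim, dim=-1)

assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```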
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
@@ -227,6 +227,7 @@ class Attention(nn.Module):
             self.to_k = None
             self.to_v = None
 
+        self.added_proj_bias = added_proj_bias
         if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
@@ -698,12 +699,15 @@ class Attention(nn.Module):
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
+            self.to_added_qkv = nn.Linear(
+                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+            )
             self.to_added_qkv.weight.copy_(concatenated_weights)
-            concatenated_bias = torch.cat(
-                [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
-            )
-            self.to_added_qkv.bias.copy_(concatenated_bias)
+            if self.added_proj_bias:
+                concatenated_bias = torch.cat(
+                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+                )
+                self.to_added_qkv.bias.copy_(concatenated_bias)
 
         self.fused_projections = fuse
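This bias handling is the substantive fix in the hunk above: `Attention` can be configured with `added_proj_bias=False`, in which case `add_q_proj.bias` is `None` and the old unconditional `bias=True` / `torch.cat` path would fail. A minimal sketch of the corrected pattern in plain PyTorch (`fuse_linears` is an illustrative helper, not part of diffusers):

```python
import torch
import torch.nn as nn

def fuse_linears(q: nn.Linear, k: nn.Linear, v: nn.Linear) -> nn.Linear:
    """Concatenate three projections into one, honoring the (shared) bias setting."""
    has_bias = q.bias is not None
    weight = torch.cat([q.weight.data, k.weight.data, v.weight.data])
    fused = nn.Linear(weight.shape[1], weight.shape[0], bias=has_bias)
    with torch.no_grad():
        fused.weight.copy_(weight)
        if has_bias:
            fused.bias.copy_(torch.cat([q.bias.data, k.bias.data, v.bias.data]))
    return fused

# With bias=False, the old unconditional `bias=True` path would have raised on `q.bias.data`.
fused = fuse_linears(*(nn.Linear(32, 32, bias=False) for _ in range(3)))
x = torch.randn(2, 32)
assert fused(x).shape == (2, 96)
```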
@@ -1274,6 +1278,103 @@ class AuraFlowAttnProcessor2_0:
         return hidden_states
 
 
+class FusedAuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow with fused projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        # Reshape.
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Concatenate the projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Attention.
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 # YiYi to-do: refactor rope related functions/classes
 def apply_rope(xq, xk, freqs_cis):
     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
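One convention worth noting in the fused processor above: AuraFlow concatenates the context (encoder) tokens *before* the sample tokens along the sequence dimension (`torch.cat([encoder_hidden_states_query_proj, query], dim=1)`), so the final split takes the trailing positions as `hidden_states` and the leading ones as `encoder_hidden_states`. A tiny sketch of that slicing with illustrative shapes:

```python
import torch

ctx_len, sample_len, dim = 8, 24, 16
joint = torch.randn(2, ctx_len + sample_len, dim)  # context tokens first

# Same slicing as the processor's "Split the attention outputs." step.
sample_out = joint[:, ctx_len:]   # -> hidden_states
context_out = joint[:, :ctx_len]  # -> encoder_hidden_states

assert sample_out.shape[1] == sample_len and context_out.shape[1] == ctx_len
```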
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -22,7 +22,12 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention_processor import Attention, AuraFlowAttnProcessor2_0
+from ..attention_processor import (
+    Attention,
+    AttentionProcessor,
+    AuraFlowAttnProcessor2_0,
+    FusedAuraFlowAttnProcessor2_0,
+)
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -320,6 +325,106 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
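With these methods in place, fusion can be toggled on a loaded pipeline's transformer. A usage sketch; the checkpoint name is an assumption for illustration:

```python
import torch
from diffusers import AuraFlowPipeline

# "fal/AuraFlow" is assumed here for illustration; use whichever AuraFlow checkpoint you have.
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

pipe.transformer.fuse_qkv_projections()    # swaps every Attention processor for the fused variant
image = pipe("a photo of a corgi").images[0]
pipe.transformer.unfuse_qkv_projections()  # restores the processors saved in original_attn_processors
```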
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -19,7 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import AttentionProcessor
+from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -247,6 +247,46 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
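Unlike AuraFlow, PixArt Sigma needs no model-specific fused processor: its attention blocks have no added KV projections, so the generic `FusedAttnProcessor2_0` suffices. Usage mirrors AuraFlow; the checkpoint name below is an assumption for illustration:

```python
import torch
from diffusers import PixArtSigmaPipeline

# Checkpoint name assumed for illustration.
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

pipe.transformer.fuse_qkv_projections()
image = pipe("a small cactus with a happy face").images[0]
pipe.transformer.unfuse_qkv_projections()
```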
diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
@@ -9,7 +9,11 @@ from diffusers.utils.testing_utils import (
     torch_device,
 )
 
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+)
 
 
 class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -119,3 +123,43 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
         # Attention slicing needs to be implemented differently for this because of how single DiT and MMDiT
         # blocks interfere with each other.
         return
+
+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
+        pipe.transformer.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_fused = image[0, -3:, -3:, -1]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_disabled = image[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
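The two `check_qkv_fusion_*` helpers used above come from `test_pipelines_common`. Roughly, they assert that every attention processor was swapped for a fused variant and that fusion did not change the number of attention layers; an illustrative approximation, not the actual helper implementations:

```python
# Illustrative approximations of the test helpers; the real ones live in
# tests/pipelines/test_pipelines_common.py and may differ in detail.
def check_qkv_fusion_processors_exist(model):
    # After fusion, every processor class name should identify a fused variant.
    return all("Fused" in proc.__class__.__name__ for proc in model.attn_processors.values())

def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors):
    # Fusion must preserve the number of attention layers/processors.
    return len(model.attn_processors) == len(original_attn_processors)
```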
diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py
@@ -36,7 +36,12 @@ from diffusers.utils.testing_utils import (
 )
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+    to_np,
+)
 
 
 enable_full_determinism()
@@ -308,6 +313,46 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=1e-3)
 
+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
+        pipe.transformer.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_fused = image[0, -3:, -3:, -1]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_disabled = image[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
+
 
 @slow
 @require_torch_gpu