Merge 2001236861 into 6bf668c4d2
This commit is contained in:
@@ -11,7 +11,8 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import math
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -38,11 +39,18 @@ from ..modeling_outputs import Transformer2DModelOutput
|
|||||||
from ..modeling_utils import ModelMixin
|
from ..modeling_utils import ModelMixin
|
||||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||||
|
|
||||||
|
from mindiesd import attention_forward as mindie_sd_attn_forward
|
||||||
|
|
||||||
|
STREAM_VECTOR = torch.npu.Stream()
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
if torch.distributed.is_available():
|
||||||
|
import torch.distributed._functional_collectives as funcol
|
||||||
|
|
||||||
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
from ..attention_dispatch import npu_fusion_attention
|
||||||
|
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True):
|
||||||
|
if cal_q:
|
||||||
query = attn.to_q(hidden_states)
|
query = attn.to_q(hidden_states)
|
||||||
key = attn.to_k(hidden_states)
|
key = attn.to_k(hidden_states)
|
||||||
value = attn.to_v(hidden_states)
|
value = attn.to_v(hidden_states)
|
||||||
@@ -52,9 +60,10 @@ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states
|
|||||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||||
|
if cal_q:
|
||||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||||
|
else:
|
||||||
|
return value, encoder_query, encoder_key, encoder_value
|
||||||
|
|
||||||
def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
||||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||||
@@ -66,11 +75,38 @@ def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_
|
|||||||
return query, key, value, encoder_query, encoder_key, encoder_value
|
return query, key, value, encoder_query, encoder_key, encoder_value
|
||||||
|
|
||||||
|
|
||||||
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True):
|
||||||
if attn.fused_projections:
|
if attn.fused_projections and cal_q:
|
||||||
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
||||||
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
return _get_projections(attn, hidden_states, encoder_hidden_states, cal_q)
|
||||||
|
|
||||||
|
def _wait_tensor(tensor):
|
||||||
|
if isinstance(tensor, funcol.AsyncCollectiveTensor):
|
||||||
|
tensor = tensor.wait()
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
|
||||||
|
shape = x.shape
|
||||||
|
x = x.flatten()
|
||||||
|
x = funcol.all_to_all_single(x, None, None, group)
|
||||||
|
x = x.reshape(shape)
|
||||||
|
x = _wait_tensor(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def ulysses_preforward(
|
||||||
|
x: torch.Tensor,
|
||||||
|
group,
|
||||||
|
world_size,
|
||||||
|
B,
|
||||||
|
S_LOCAL,
|
||||||
|
H,
|
||||||
|
D,
|
||||||
|
H_LOCAL
|
||||||
|
):
|
||||||
|
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
|
||||||
|
x = x.flatten()
|
||||||
|
x = funcol.all_to_all_single(x, None, None, group)
|
||||||
|
return x
|
||||||
|
|
||||||
class FluxAttnProcessor:
|
class FluxAttnProcessor:
|
||||||
_attention_backend = None
|
_attention_backend = None
|
||||||
@@ -87,11 +123,27 @@ class FluxAttnProcessor:
|
|||||||
encoder_hidden_states: torch.Tensor = None,
|
encoder_hidden_states: torch.Tensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
|
pre_query: Optional[torch.Tensor] = None,
|
||||||
|
pre_key: Optional[torch.Tensor] = None,
|
||||||
|
cal_q=True
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
if hasattr(self._parallel_config, "context_parallel_config") and \
|
||||||
attn, hidden_states, encoder_hidden_states
|
self._parallel_config.context_parallel_config is not None:
|
||||||
|
|
||||||
|
return self._context_parallel_forward_qkv(
|
||||||
|
attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, pre_query, pre_key, cal_q
|
||||||
)
|
)
|
||||||
|
|
||||||
|
qkv_proj_out = _get_qkv_projections(
|
||||||
|
attn, hidden_states, encoder_hidden_states, cal_q
|
||||||
|
)
|
||||||
|
if cal_q:
|
||||||
|
query, key, value, encoder_query, encoder_key, encoder_value = qkv_proj_out
|
||||||
|
else:
|
||||||
|
value, encoder_query, encoder_key, encoder_value = qkv_proj_out
|
||||||
|
query = pre_query
|
||||||
|
key = pre_key
|
||||||
|
|
||||||
query = query.unflatten(-1, (attn.heads, -1))
|
query = query.unflatten(-1, (attn.heads, -1))
|
||||||
key = key.unflatten(-1, (attn.heads, -1))
|
key = key.unflatten(-1, (attn.heads, -1))
|
||||||
value = value.unflatten(-1, (attn.heads, -1))
|
value = value.unflatten(-1, (attn.heads, -1))
|
||||||
@@ -138,6 +190,106 @@ class FluxAttnProcessor:
|
|||||||
else:
|
else:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def _context_parallel_forward_qkv(
|
||||||
|
self,
|
||||||
|
attn: "FluxAttention",
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
encoder_hidden_states: torch.Tensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
|
pre_query: Optional[torch.Tensor] = None,
|
||||||
|
pre_key: Optional[torch.Tensor] = None,
|
||||||
|
cal_q=True
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh
|
||||||
|
world_size = self._parallel_config.context_parallel_config.ulysses_degree
|
||||||
|
group = ulysses_mesh.get_group()
|
||||||
|
|
||||||
|
ev_q = torch.npu.Event()
|
||||||
|
ev_k = torch.npu.Event()
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
query = query.unflatten(-1, (attn.heads, -1))
|
||||||
|
ev_q.record()
|
||||||
|
key = attn.to_k(hidden_states)
|
||||||
|
key = key.unflatten(-1, (attn.heads, -1))
|
||||||
|
ev_k.record()
|
||||||
|
|
||||||
|
value = attn.to_v(hidden_states)
|
||||||
|
value = value.unflatten(-1, (attn.heads, -1))
|
||||||
|
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
||||||
|
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||||
|
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
||||||
|
value = torch.cat([encoder_value, value], dim=1)
|
||||||
|
|
||||||
|
with torch.npu.stream(STREAM_VECTOR):
|
||||||
|
ev_q.wait()
|
||||||
|
query = attn.norm_q(query)
|
||||||
|
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
||||||
|
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||||
|
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
||||||
|
encoder_query = attn.norm_added_q(encoder_query)
|
||||||
|
query = torch.cat([encoder_query, query], dim=1)
|
||||||
|
if image_rotary_emb is not None:
|
||||||
|
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||||
|
|
||||||
|
B, S_Q_LOCAL, H, D = query.shape
|
||||||
|
H_LOCAL = H // world_size
|
||||||
|
query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL)
|
||||||
|
|
||||||
|
ev_k.wait()
|
||||||
|
key = attn.norm_k(key)
|
||||||
|
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
||||||
|
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||||
|
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
||||||
|
encoder_key = attn.norm_added_k(encoder_key)
|
||||||
|
key = torch.cat([encoder_key, key], dim=1)
|
||||||
|
if image_rotary_emb is not None:
|
||||||
|
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||||
|
|
||||||
|
_, S_KV_LOCAL, _, _ = key.shape
|
||||||
|
key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)
|
||||||
|
|
||||||
|
value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)
|
||||||
|
|
||||||
|
query_all = _wait_tensor(query_all)
|
||||||
|
query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
|
||||||
|
|
||||||
|
key_all = _wait_tensor(key_all)
|
||||||
|
key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
|
||||||
|
|
||||||
|
value_all = _wait_tensor(value_all)
|
||||||
|
value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
|
||||||
|
|
||||||
|
out = mindie_sd_attn_forward(
|
||||||
|
query_all,
|
||||||
|
key_all,
|
||||||
|
value_all,
|
||||||
|
opt_mode="manual",
|
||||||
|
op_type="ascend_laser_attention",
|
||||||
|
layout="BNSD"
|
||||||
|
)
|
||||||
|
|
||||||
|
out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
|
||||||
|
out = _all_to_all_single(out, group)
|
||||||
|
hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
|
||||||
|
|
||||||
|
hidden_states = hidden_states.flatten(2, 3)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||||
|
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||||
|
)
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, encoder_hidden_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||||
"""Flux Attention processor for IP-Adapter."""
|
"""Flux Attention processor for IP-Adapter."""
|
||||||
@@ -633,6 +785,7 @@ class FluxTransformer2DModel(
|
|||||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
self.image_rotary_emb = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -717,11 +870,15 @@ class FluxTransformer2DModel(
|
|||||||
img_ids = img_ids[0]
|
img_ids = img_ids[0]
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||||
|
if self.image_rotary_emb is None:
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
|
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
|
||||||
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
|
self.image_rotary_emb = (
|
||||||
|
freqs_cos.npu().to(hidden_states.dtype),
|
||||||
|
freqs_sin.npu().to(hidden_states.dtype)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
image_rotary_emb = self.pos_embed(ids)
|
self.image_rotary_emb = self.pos_embed(ids)
|
||||||
|
|
||||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||||
@@ -735,7 +892,7 @@ class FluxTransformer2DModel(
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
temb,
|
temb,
|
||||||
image_rotary_emb,
|
self.image_rotary_emb,
|
||||||
joint_attention_kwargs,
|
joint_attention_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -744,7 +901,7 @@ class FluxTransformer2DModel(
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
temb=temb,
|
temb=temb,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=self.image_rotary_emb,
|
||||||
joint_attention_kwargs=joint_attention_kwargs,
|
joint_attention_kwargs=joint_attention_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -767,7 +924,7 @@ class FluxTransformer2DModel(
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
temb,
|
temb,
|
||||||
image_rotary_emb,
|
self.image_rotary_emb,
|
||||||
joint_attention_kwargs,
|
joint_attention_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -776,7 +933,7 @@ class FluxTransformer2DModel(
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
temb=temb,
|
temb=temb,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=self.image_rotary_emb,
|
||||||
joint_attention_kwargs=joint_attention_kwargs,
|
joint_attention_kwargs=joint_attention_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user