zhangtao0408
2025-11-27 12:24:02 +03:00
committed by GitHub
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -38,11 +39,18 @@ from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from mindiesd import attention_forward as mindie_sd_attn_forward
STREAM_VECTOR = torch.npu.Stream()
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
if torch.distributed.is_available():
import torch.distributed._functional_collectives as funcol
from ..attention_dispatch import npu_fusion_attention


def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True):
    if cal_q:
        query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)
@@ -52,9 +60,10 @@ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)
    if cal_q:
        return query, key, value, encoder_query, encoder_key, encoder_value
    else:
        return value, encoder_query, encoder_key, encoder_value

def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
@@ -66,11 +75,38 @@ def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_
    return query, key, value, encoder_query, encoder_key, encoder_value

def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True):
    if attn.fused_projections and cal_q:
        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
    return _get_projections(attn, hidden_states, encoder_hidden_states, cal_q)
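
# --- Editor's illustration (not part of the diff) ---------------------------
# Minimal single-device sketch of the new cal_q / pre_query / pre_key contract:
# on a repeat call the caller passes back the previously computed query/key and
# asks only for a fresh value projection. _StubAttn and _projections below are
# hypothetical stand-ins that mirror just the attributes the real
# _get_projections touches (no encoder stream, no fused projections).
import torch
import torch.nn as nn

class _StubAttn(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.to_q, self.to_k, self.to_v = (nn.Linear(dim, dim) for _ in range(3))

def _projections(attn, hidden_states, cal_q=True):
    # Mirrors the cal_q branch added to _get_projections above.
    if cal_q:
        return attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)
    # value only; query/key are expected to arrive via pre_query / pre_key
    return attn.to_v(hidden_states)

attn = _StubAttn()
x = torch.randn(2, 16, 64)
pre_query, pre_key, value = _projections(attn, x, cal_q=True)   # first call
value_only = _projections(attn, x, cal_q=False)                  # later call reuses q/k
assert torch.allclose(value, value_only)
# -----------------------------------------------------------------------------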
def _wait_tensor(tensor):
if isinstance(tensor, funcol.AsyncCollectiveTensor):
tensor = tensor.wait()
return tensor
def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
shape = x.shape
x = x.flatten()
x = funcol.all_to_all_single(x, None, None, group)
x = x.reshape(shape)
x = _wait_tensor(x)
return x
def ulysses_preforward(
x: torch.Tensor,
group,
world_size,
B,
S_LOCAL,
H,
D,
H_LOCAL
):
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
x = x.flatten()
x = funcol.all_to_all_single(x, None, None, group)
return x
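
# --- Editor's illustration (not part of the diff) ---------------------------
# Single-process sketch of what the Ulysses all-to-all achieves: before the
# collective each rank holds its local sequence shard for all heads
# (B, S_LOCAL, H, D); afterwards it holds the full sequence for its local head
# group (B, S_FULL, H_LOCAL, D). The sizes below are made up; the reshapes are
# the same ones used in ulysses_preforward and after _wait_tensor.
import torch

W, B, S_LOCAL, H, D = 4, 1, 6, 8, 16
H_LOCAL, S_FULL = H // W, W * S_LOCAL

full = torch.randn(B, S_FULL, H, D)          # full sequence, all heads
shards = full.chunk(W, dim=1)                # shards[r]: rank r's input (B, S_LOCAL, H, D)

def send_layout(x):
    # same reshape/permute as ulysses_preforward before the collective
    return x.reshape(B, S_LOCAL, W, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()

sends = [send_layout(s) for s in shards]     # sends[r][i] is what rank r mails to rank i

for j in range(W):
    # emulate all_to_all_single on rank j: chunk j from every rank, in rank order
    recv = torch.stack([sends[r][j] for r in range(W)], dim=0)   # (W, S_LOCAL, B, H_LOCAL, D)
    out = recv.flatten(0, 1).permute(1, 0, 2, 3)                 # (B, S_FULL, H_LOCAL, D)
    # rank j now sees every sequence position, but only its slice of the heads
    assert torch.equal(out, full[:, :, j * H_LOCAL:(j + 1) * H_LOCAL, :])
# -----------------------------------------------------------------------------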

class FluxAttnProcessor:
    _attention_backend = None
@@ -87,11 +123,27 @@ class FluxAttnProcessor:
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        pre_query: Optional[torch.Tensor] = None,
        pre_key: Optional[torch.Tensor] = None,
        cal_q=True
    ) -> torch.Tensor:
        if hasattr(self._parallel_config, "context_parallel_config") and \
                self._parallel_config.context_parallel_config is not None:
            return self._context_parallel_forward_qkv(
                attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, pre_query, pre_key, cal_q
            )
        qkv_proj_out = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states, cal_q
        )
        if cal_q:
            query, key, value, encoder_query, encoder_key, encoder_value = qkv_proj_out
        else:
            value, encoder_query, encoder_key, encoder_value = qkv_proj_out
            query = pre_query
            key = pre_key
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))
@@ -138,6 +190,106 @@ class FluxAttnProcessor:
        else:
            return hidden_states
def _context_parallel_forward_qkv(
self,
attn: "FluxAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
pre_query: Optional[torch.Tensor] = None,
pre_key: Optional[torch.Tensor] = None,
cal_q=True
) -> torch.Tensor:
ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh
world_size = self._parallel_config.context_parallel_config.ulysses_degree
group = ulysses_mesh.get_group()
ev_q = torch.npu.Event()
ev_k = torch.npu.Event()
query = attn.to_q(hidden_states)
query = query.unflatten(-1, (attn.heads, -1))
ev_q.record()
key = attn.to_k(hidden_states)
key = key.unflatten(-1, (attn.heads, -1))
ev_k.record()
value = attn.to_v(hidden_states)
value = value.unflatten(-1, (attn.heads, -1))
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
value = torch.cat([encoder_value, value], dim=1)
with torch.npu.stream(STREAM_VECTOR):
ev_q.wait()
query = attn.norm_q(query)
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
query = torch.cat([encoder_query, query], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
B, S_Q_LOCAL, H, D = query.shape
H_LOCAL = H // world_size
query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL)
ev_k.wait()
key = attn.norm_k(key)
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_key = attn.norm_added_k(encoder_key)
key = torch.cat([encoder_key, key], dim=1)
if image_rotary_emb is not None:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
_, S_KV_LOCAL, _, _ = key.shape
key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)
value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)
query_all = _wait_tensor(query_all)
query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
key_all = _wait_tensor(key_all)
key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
value_all = _wait_tensor(value_all)
value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
out = mindie_sd_attn_forward(
query_all,
key_all,
value_all,
opt_mode="manual",
op_type="ascend_laser_attention",
layout="BNSD"
)
out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
out = _all_to_all_single(out, group)
hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states

class FluxIPAdapterAttnProcessor(torch.nn.Module):
    """Flux Attention processor for IP-Adapter."""
@@ -633,6 +785,7 @@ class FluxTransformer2DModel(
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False
self.image_rotary_emb = None

    def forward(
        self,
@@ -717,11 +870,15 @@ class FluxTransformer2DModel(
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        if self.image_rotary_emb is None:
            if is_torch_npu_available():
                freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
                self.image_rotary_emb = (
                    freqs_cos.npu().to(hidden_states.dtype),
                    freqs_sin.npu().to(hidden_states.dtype)
                )
            else:
                self.image_rotary_emb = self.pos_embed(ids)
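
# --- Editor's illustration (not part of the diff) ---------------------------
# Single-device sketch of the caching pattern above: the rotary tables are
# computed on the first forward call and reused on later denoising steps, on
# the assumption that txt_ids/img_ids do not change between steps. _FakeRope
# and _Model are hypothetical stand-ins for self.pos_embed and the transformer.
import torch

class _FakeRope:
    calls = 0
    def __call__(self, ids):
        _FakeRope.calls += 1
        return torch.cos(ids.float()), torch.sin(ids.float())

class _Model:
    def __init__(self):
        self.pos_embed = _FakeRope()
        self.image_rotary_emb = None
    def forward(self, ids, dtype=torch.float16):
        if self.image_rotary_emb is None:
            cos, sin = self.pos_embed(ids)
            self.image_rotary_emb = (cos.to(dtype), sin.to(dtype))
        return self.image_rotary_emb

m = _Model()
ids = torch.arange(8)
m.forward(ids)
m.forward(ids)
assert _FakeRope.calls == 1   # pos_embed ran once; later steps reuse the cache
# -----------------------------------------------------------------------------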
        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
@@ -735,7 +892,7 @@ class FluxTransformer2DModel(
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    self.image_rotary_emb,
                    joint_attention_kwargs,
                )
@@ -744,7 +901,7 @@ class FluxTransformer2DModel(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=self.image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )
@@ -767,7 +924,7 @@ class FluxTransformer2DModel(
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    self.image_rotary_emb,
                    joint_attention_kwargs,
                )
@@ -776,7 +933,7 @@ class FluxTransformer2DModel(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=self.image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )