This commit is contained in:
yiyixuxu
2025-11-26 22:20:25 +01:00
parent e6d4612309
commit e2c62d6798
@@ -97,7 +97,8 @@ class ZSingleStreamAttnProcessor:
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
@@ -113,17 +114,26 @@ class ZSingleStreamAttnProcessor:
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE
# Complex-valued rotary embedding: consecutive pairs of the last dimension of
# `x_in` are viewed as complex numbers, multiplied element-wise by `freqs_cis`,
# and converted back to interleaved real pairs in the dtype of `x_in`.
# NOTE(review): a second `apply_rotary_emb` taking separate cos/sin tables is
# defined further down in this span — confirm which definition is live.
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
# CUDA autocast is switched off for the block below, presumably so the math
# stays in float32 — TODO confirm.
with torch.amp.autocast("cuda", enabled=False):
# Upcast to float32, split the last dim into (n_pairs, 2), view each pair as one complex number.
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
# Insert a singleton axis at position 2 so the table broadcasts over that axis of `x`
# (assumes axis 2 of `x` is the attention-head axis — TODO confirm).
freqs_cis = freqs_cis.unsqueeze(2)
# Complex multiply, then merge everything from axis 3 onward back into one real axis.
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
# Return in the caller's original dtype.
return x_out.type_as(x_in) # todo
# # Apply RoPE
# def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
# with torch.amp.autocast("cuda", enabled=False):
# x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
# freqs_cis = freqs_cis.unsqueeze(2)
# x_out = torch.view_as_real(x * freqs_cis).flatten(3)
# return x_out.type_as(x_in) # todo
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
def apply_rotary_emb(x_in: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved feature pairs of ``x_in`` by precomputed angles.

    The last dimension of ``x_in`` is read as consecutive pairs ``(x0, x1)``; each pair is mapped to
    ``(x0 * cos - x1 * sin, x0 * sin + x1 * cos)``, the real-valued form of multiplying
    ``x0 + i * x1`` by ``cos + i * sin``.

    Args:
        x_in: Tensor whose last dimension has even length. Presumably laid out as
            ``[batch, seq, heads, head_dim]`` (TODO confirm against the caller).
        freqs_cos: Cosines of the rotation angles. A singleton axis is inserted at position 2
            before broadcasting against the pair axis of ``x_in``.
        freqs_sin: Sines of the rotation angles, same layout as ``freqs_cos``.

    Returns:
        A tensor with the shape and dtype of ``x_in``.
    """
    # Compute in at least float32 no matter what dtype the activations or tables arrive in, so
    # half-precision inputs are rounded once at the end rather than after every multiply/add.
    # Wider inputs (e.g. float64) keep their own precision.
    compute_dtype = torch.promote_types(torch.promote_types(x_in.dtype, freqs_cos.dtype), torch.float32)

    cos = freqs_cos.unsqueeze(2).to(compute_dtype)
    sin = freqs_sin.unsqueeze(2).to(compute_dtype)

    # Split the last dim into (n_pairs, 2) and pull the two components apart.
    pairs = x_in.reshape(*x_in.shape[:-1], -1, 2).to(compute_dtype)
    x0, x1 = pairs.unbind(dim=-1)

    out0 = x0 * cos - x1 * sin
    out1 = x0 * sin + x1 * cos

    # Re-interleave the rotated components and hand back the caller's dtype.
    return torch.stack([out0, out1], dim=-1).flatten(-2).type_as(x_in)
if freqs_cos is not None and freqs_sin is not None:
query = apply_rotary_emb(query, freqs_cos, freqs_sin)
key = apply_rotary_emb(key, freqs_cos, freqs_sin)
# Cast to correct dtype
dtype = query.dtype
@@ -219,7 +229,8 @@ class ZImageTransformerBlock(nn.Module):
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
):
if self.modulation:
@@ -232,7 +243,8 @@ class ZImageTransformerBlock(nn.Module):
attn_out = self.attention(
self.attention_norm1(x) * scale_msa,
attention_mask=attn_mask,
freqs_cis=freqs_cis,
freqs_cos=freqs_cos,
freqs_sin=freqs_sin,
)
x = x + gate_msa * self.attention_norm2(attn_out)
@@ -247,7 +259,8 @@ class ZImageTransformerBlock(nn.Module):
attn_out = self.attention(
self.attention_norm1(x),
attention_mask=attn_mask,
freqs_cis=freqs_cis,
freqs_cos=freqs_cos,
freqs_sin=freqs_sin,
)
x = x + self.attention_norm2(attn_out)
@@ -290,39 +303,48 @@ class RopeEmbedder:
self.axes_dims = axes_dims
self.axes_lens = axes_lens
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
self.freqs_cis = None
self.freqs_cos = None
self.freqs_sin = None
# Builds one rotary lookup table per positional axis. For every (d, e) pair taken
# from `dim` and `end`, the angle table has `e` rows (positions 0..e-1) and one
# column per value of range(0, d, 2).
# NOTE(review): this span carries both a complex-valued `freqs_cis` path and a
# cos/sin path, including two consecutive `return` statements. It looks like the
# pre-change and post-change lines of the commit shown together — confirm which
# lines are live before relying on the return type.
@staticmethod
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
# Tables are created under a CPU device context.
with torch.device("cpu"):
freqs_cis = []
freqs_cos = []
freqs_sin = []
for i, (d, e) in enumerate(zip(dim, end)):
# Inverse frequencies theta ** -(k / d) for k = 0, 2, 4, ..., computed in float64.
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
# Integer positions 0..e-1 along this axis.
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
# Angle table: position x inverse frequency, downcast to float32.
freqs = torch.outer(timestep, freqs).float()
# Complex-valued path: unit-magnitude complex numbers with the angles above.
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
freqs_cis.append(freqs_cis_i)
# Real-valued path: store cosine and sine of the same angles separately.
freqs_cos.append(freqs.cos())
freqs_sin.append(freqs.sin())
# freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
# freqs_cis.append(freqs_cis_i)
return freqs_cis
return freqs_cos, freqs_sin
# Looks up rotary table rows for a batch of position ids. Column `i` of `ids`
# indexes the table of axis `i`; the gathered rows of all axes are concatenated
# along the last dimension.
# NOTE(review): this span carries both a single-table `freqs_cis` path and a
# paired cos/sin path, including two `return` statements. It looks like the
# pre-change and post-change lines of the commit shown together — confirm which
# lines are live before relying on the return type.
def __call__(self, ids: torch.Tensor):
# `ids` must be 2-D with exactly one column per configured axis.
assert ids.ndim == 2
assert ids.shape[-1] == len(self.axes_dims)
device = ids.device
# Complex-valued path: build the tables on first use, then move them to the device of `ids`.
if self.freqs_cis is None:
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
# cos/sin path: build both table lists on first use, then move them to the device of `ids`.
if self.freqs_cos is None or self.freqs_sin is None:
self.freqs_cos, self.freqs_sin = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cos = [f.to(device) for f in self.freqs_cos]
self.freqs_sin = [f.to(device) for f in self.freqs_sin]
else:
# Ensure the cached tables are on the same device as ids
# (only the first table of each list is checked).
if self.freqs_cis[0].device != device:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
if self.freqs_cos[0].device != device:
self.freqs_cos = [f.to(device) for f in self.freqs_cos]
if self.freqs_sin[0].device != device:
self.freqs_sin = [f.to(device) for f in self.freqs_sin]
result = []
cos_result = []
sin_result = []
# Gather, per axis, the table rows selected by that axis' id column.
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)
cos_result.append(self.freqs_cos[i][index])
sin_result.append(self.freqs_sin[i][index])
# Concatenate the per-axis slices along the feature dimension.
return torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1)
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
@@ -587,20 +609,23 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
adaln_input = t.type_as(x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
x_freqs_cos, x_freqs_sin =self.rope_embedder(torch.cat(x_pos_ids, dim=0))
x_freqs_cos = list(x_freqs_cos.split(x_item_seqlens, dim=0))
x_freqs_sin = list(x_freqs_sin.split(x_item_seqlens, dim=0))
x = pad_sequence(x, batch_first=True, padding_value=0.0)
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
x_freqs_cos = pad_sequence(x_freqs_cos, batch_first=True, padding_value=0.0)
x_freqs_sin = pad_sequence(x_freqs_sin, batch_first=True, padding_value=0.0)
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.noise_refiner:
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)
else:
for layer in self.noise_refiner:
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
x = layer(x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)
# cap embed & refine
cap_item_seqlens = [len(_) for _ in cap_feats]
@@ -611,35 +636,41 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
cap_freqs_cos, cap_freqs_sin = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
cap_freqs_cos = list(cap_freqs_cos.split(cap_item_seqlens, dim=0))
cap_freqs_sin = list(cap_freqs_sin.split(cap_item_seqlens, dim=0))
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
cap_freqs_cos = pad_sequence(cap_freqs_cos, batch_first=True, padding_value=0.0)
cap_freqs_sin = pad_sequence(cap_freqs_sin, batch_first=True, padding_value=0.0)
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
cap_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.context_refiner:
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)
else:
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)
# unified
unified = []
unified_freqs_cis = []
unified_freqs_cos = []
unified_freqs_sin = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
unified_freqs_cos.append(torch.cat([x_freqs_cos[i][:x_len], cap_freqs_cos[i][:cap_len]]))
unified_freqs_sin.append(torch.cat([x_freqs_sin[i][:x_len], cap_freqs_sin[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
unified_freqs_cos = pad_sequence(unified_freqs_cos, batch_first=True, padding_value=0.0)
unified_freqs_sin = pad_sequence(unified_freqs_sin, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
unified_attn_mask[i, :seq_len] = 1
@@ -647,11 +678,11 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.layers:
unified = self._gradient_checkpointing_func(
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
layer, unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input
)
else:
for layer in self.layers:
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
unified = layer(unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input)
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
unified = list(unified.unbind(dim=0))