up
This commit is contained in:
@@ -97,7 +97,8 @@ class ZSingleStreamAttnProcessor:
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
freqs_cos: Optional[torch.Tensor] = None,
|
||||
freqs_sin: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
@@ -113,17 +114,26 @@ class ZSingleStreamAttnProcessor:
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in) # todo
|
||||
# # Apply RoPE
|
||||
# def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
# with torch.amp.autocast("cuda", enabled=False):
|
||||
# x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
# freqs_cis = freqs_cis.unsqueeze(2)
|
||||
# x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
# return x_out.type_as(x_in) # todo
|
||||
|
||||
if freqs_cis is not None:
|
||||
query = apply_rotary_emb(query, freqs_cis)
|
||||
key = apply_rotary_emb(key, freqs_cis)
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
|
||||
freqs_cos = freqs_cos.unsqueeze(2) # [batch, seq, 1, head_dim//2]
|
||||
freqs_sin = freqs_sin.unsqueeze(2)
|
||||
x = x_in.reshape(*x_in.shape[:-1], -1, 2)
|
||||
x0, x1 = x[..., 0], x[..., 1]
|
||||
out0 = x0 * freqs_cos - x1 * freqs_sin
|
||||
out1 = x0 * freqs_sin + x1 * freqs_cos
|
||||
return torch.stack([out0, out1], dim=-1).flatten(-2).type_as(x_in)
|
||||
|
||||
if freqs_cos is not None and freqs_sin is not None:
|
||||
query = apply_rotary_emb(query, freqs_cos, freqs_sin)
|
||||
key = apply_rotary_emb(key, freqs_cos, freqs_sin)
|
||||
|
||||
# Cast to correct dtype
|
||||
dtype = query.dtype
|
||||
@@ -219,7 +229,8 @@ class ZImageTransformerBlock(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
@@ -232,7 +243,8 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x) * scale_msa,
|
||||
attention_mask=attn_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
freqs_cos=freqs_cos,
|
||||
freqs_sin=freqs_sin,
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
@@ -247,7 +259,8 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x),
|
||||
attention_mask=attn_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
freqs_cos=freqs_cos,
|
||||
freqs_sin=freqs_sin,
|
||||
)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
|
||||
@@ -290,39 +303,48 @@ class RopeEmbedder:
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
|
||||
self.freqs_cis = None
|
||||
self.freqs_cos = None
|
||||
self.freqs_sin = None
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
|
||||
with torch.device("cpu"):
|
||||
freqs_cis = []
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
|
||||
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||
freqs = torch.outer(timestep, freqs).float()
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||
freqs_cis.append(freqs_cis_i)
|
||||
freqs_cos.append(freqs.cos())
|
||||
freqs_sin.append(freqs.sin())
|
||||
# freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||
# freqs_cis.append(freqs_cis_i)
|
||||
|
||||
return freqs_cis
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
assert ids.ndim == 2
|
||||
assert ids.shape[-1] == len(self.axes_dims)
|
||||
device = ids.device
|
||||
|
||||
if self.freqs_cis is None:
|
||||
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
if self.freqs_cos is None or self.freqs_sin is None:
|
||||
self.freqs_cos, self.freqs_sin = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
self.freqs_cos = [f.to(device) for f in self.freqs_cos]
|
||||
self.freqs_sin = [f.to(device) for f in self.freqs_sin]
|
||||
else:
|
||||
# Ensure freqs_cis are on the same device as ids
|
||||
if self.freqs_cis[0].device != device:
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
if self.freqs_cos[0].device != device:
|
||||
self.freqs_cos = [f.to(device) for f in self.freqs_cos]
|
||||
if self.freqs_sin[0].device != device:
|
||||
self.freqs_sin = [f.to(device) for f in self.freqs_sin]
|
||||
|
||||
result = []
|
||||
cos_result = []
|
||||
sin_result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
cos_result.append(self.freqs_cos[i][index])
|
||||
sin_result.append(self.freqs_sin[i][index])
|
||||
return torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1)
|
||||
|
||||
|
||||
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
@@ -587,20 +609,23 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
adaln_input = t.type_as(x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||
x_freqs_cos, x_freqs_sin =self.rope_embedder(torch.cat(x_pos_ids, dim=0))
|
||||
x_freqs_cos = list(x_freqs_cos.split(x_item_seqlens, dim=0))
|
||||
x_freqs_sin = list(x_freqs_sin.split(x_item_seqlens, dim=0))
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cos = pad_sequence(x_freqs_cos, batch_first=True, padding_value=0.0)
|
||||
x_freqs_sin = pad_sequence(x_freqs_sin, batch_first=True, padding_value=0.0)
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.noise_refiner:
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)
|
||||
else:
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
x = layer(x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
@@ -611,35 +636,41 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cos, cap_freqs_sin = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
|
||||
cap_freqs_cos = list(cap_freqs_cos.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_sin = list(cap_freqs_sin.split(cap_item_seqlens, dim=0))
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cos = pad_sequence(cap_freqs_cos, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_sin = pad_sequence(cap_freqs_sin, batch_first=True, padding_value=0.0)
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)
|
||||
else:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
unified_freqs_cos = []
|
||||
unified_freqs_sin = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_freqs_cos.append(torch.cat([x_freqs_cos[i][:x_len], cap_freqs_cos[i][:cap_len]]))
|
||||
unified_freqs_sin.append(torch.cat([x_freqs_sin[i][:x_len], cap_freqs_sin[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cos = pad_sequence(unified_freqs_cos, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_sin = pad_sequence(unified_freqs_sin, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
@@ -647,11 +678,11 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.layers:
|
||||
unified = self._gradient_checkpointing_func(
|
||||
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
layer, unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input
|
||||
)
|
||||
else:
|
||||
for layer in self.layers:
|
||||
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
||||
unified = layer(unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input)
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
|
||||
Reference in New Issue
Block a user