Replace padding with pad_sequence; Add gradient checkpointing.
This commit is contained in:
parent
2bb39f46cd
commit
71e8049a84
@ -19,6 +19,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
@ -355,6 +356,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
self.t_scale = t_scale
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
assert len(all_patch_size) == len(all_f_patch_size)
|
||||
|
||||
@ -579,29 +581,18 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||
|
||||
pad_tensor = torch.zeros((1, self.dim), dtype=x[0].dtype, device=device)
|
||||
freqs_pad_tensor = torch.zeros(
|
||||
(1, self.dim // self.n_heads // 2),
|
||||
dtype=x_freqs_cis[0].dtype,
|
||||
device=device,
|
||||
)
|
||||
x_attn_mask = torch.ones((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, (item, freqs_item) in enumerate(zip(x, x_freqs_cis)):
|
||||
seq_len = x_item_seqlens[i]
|
||||
pad_len = x_max_item_seqlen - seq_len
|
||||
x[i] = torch.cat([item, pad_tensor.repeat(pad_len, 1)])
|
||||
x_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
|
||||
x_attn_mask[i, seq_len:] = 0
|
||||
x = torch.stack(x)
|
||||
x_freqs_cis = torch.stack(x_freqs_cis)
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(
|
||||
x,
|
||||
x_attn_mask,
|
||||
x_freqs_cis,
|
||||
adaln_input,
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.noise_refiner:
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
else:
|
||||
for layer in self.noise_refiner:
|
||||
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
@ -614,29 +605,18 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
||||
|
||||
# Reuse padding tensors (convert dtype if needed)
|
||||
cap_pad_tensor = pad_tensor.to(cap_feats[0].dtype) if pad_tensor.dtype != cap_feats[0].dtype else pad_tensor
|
||||
cap_freqs_pad_tensor = (
|
||||
freqs_pad_tensor.to(cap_freqs_cis[0].dtype)
|
||||
if freqs_pad_tensor.dtype != cap_freqs_cis[0].dtype
|
||||
else freqs_pad_tensor
|
||||
)
|
||||
cap_attn_mask = torch.ones((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, (item, freqs_item) in enumerate(zip(cap_feats, cap_freqs_cis)):
|
||||
seq_len = cap_item_seqlens[i]
|
||||
pad_len = cap_max_item_seqlen - seq_len
|
||||
cap_feats[i] = torch.cat([item, cap_pad_tensor.repeat(pad_len, 1)])
|
||||
cap_freqs_cis[i] = torch.cat([freqs_item, cap_freqs_pad_tensor.repeat(pad_len, 1)])
|
||||
cap_attn_mask[i, seq_len:] = 0
|
||||
cap_feats = torch.stack(cap_feats)
|
||||
cap_freqs_cis = torch.stack(cap_freqs_cis)
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(
|
||||
cap_feats,
|
||||
cap_attn_mask,
|
||||
cap_freqs_cis,
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
else:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
@ -650,29 +630,18 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
|
||||
unified_pad_tensor = pad_tensor.to(unified[0].dtype) if pad_tensor.dtype != unified[0].dtype else pad_tensor
|
||||
unified_freqs_pad_tensor = (
|
||||
freqs_pad_tensor.to(unified_freqs_cis[0].dtype)
|
||||
if freqs_pad_tensor.dtype != unified_freqs_cis[0].dtype
|
||||
else freqs_pad_tensor
|
||||
)
|
||||
unified_attn_mask = torch.ones((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, (item, freqs_item) in enumerate(zip(unified, unified_freqs_cis)):
|
||||
seq_len = unified_item_seqlens[i]
|
||||
pad_len = unified_max_item_seqlen - seq_len
|
||||
unified[i] = torch.cat([item, unified_pad_tensor.repeat(pad_len, 1)])
|
||||
unified_freqs_cis[i] = torch.cat([freqs_item, unified_freqs_pad_tensor.repeat(pad_len, 1)])
|
||||
unified_attn_mask[i, seq_len:] = 0
|
||||
unified = torch.stack(unified)
|
||||
unified_freqs_cis = torch.stack(unified_freqs_cis)
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
|
||||
for layer in self.layers:
|
||||
unified = layer(
|
||||
unified,
|
||||
unified_attn_mask,
|
||||
unified_freqs_cis,
|
||||
adaln_input,
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.layers:
|
||||
unified = self._gradient_checkpointing_func(layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
||||
else:
|
||||
for layer in self.layers:
|
||||
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user