Replace manual padding loop with pad_sequence; add gradient checkpointing.

Jerry Qilong Wu 2025-11-24 18:37:24 +00:00
parent 2bb39f46cd
commit 71e8049a84

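For context on the first change: torch.nn.utils.rnn.pad_sequence pads a list of variable-length tensors with a constant up to the longest length and stacks the result, which is exactly what the removed pad-and-stack loops did by hand. A minimal sketch under made-up shapes (three items, feature dim 4) — every tensor and name here is illustrative, not from the diff:

import torch
from torch.nn.utils.rnn import pad_sequence

items = [torch.randn(3, 4), torch.randn(5, 4), torch.randn(2, 4)]  # variable lengths
seqlens = [t.shape[0] for t in items]  # [3, 5, 2]
max_len = max(seqlens)

# One call replaces the manual pad_tensor.repeat(...) + torch.cat + torch.stack dance:
# result is (batch, max_len, 4), with padded positions filled with 0.0.
batch = pad_sequence(items, batch_first=True, padding_value=0.0)

# pad_sequence returns no mask, so the valid-token mask is rebuilt from the lengths,
# mirroring what the diff does for x_attn_mask / cap_attn_mask / unified_attn_mask.
mask = torch.zeros((len(items), max_len), dtype=torch.bool)
for i, n in enumerate(seqlens):
    mask[i, :n] = 1

assert batch.shape == (3, 5, 4)
assert mask.sum().item() == sum(seqlens)

One behavioral note: the old code kept dtype-converted pad tensors per stream, while pad_sequence pads each stream with 0.0 in that stream's own dtype, so the dtype bookkeeping disappears from the diff entirely.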

@@ -19,6 +19,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from torch.nn.utils.rnn import pad_sequence
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
@@ -355,6 +356,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
         self.rope_theta = rope_theta
         self.t_scale = t_scale
+        self.gradient_checkpointing = False
 
         assert len(all_patch_size) == len(all_f_patch_size)
@@ -579,29 +581,18 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
         x = list(x.split(x_item_seqlens, dim=0))
         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
-        pad_tensor = torch.zeros((1, self.dim), dtype=x[0].dtype, device=device)
-        freqs_pad_tensor = torch.zeros(
-            (1, self.dim // self.n_heads // 2),
-            dtype=x_freqs_cis[0].dtype,
-            device=device,
-        )
-        x_attn_mask = torch.ones((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, (item, freqs_item) in enumerate(zip(x, x_freqs_cis)):
-            seq_len = x_item_seqlens[i]
-            pad_len = x_max_item_seqlen - seq_len
-            x[i] = torch.cat([item, pad_tensor.repeat(pad_len, 1)])
-            x_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
-            x_attn_mask[i, seq_len:] = 0
-        x = torch.stack(x)
-        x_freqs_cis = torch.stack(x_freqs_cis)
+        x = pad_sequence(x, batch_first=True, padding_value=0.0)
+        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
+        for i, seq_len in enumerate(x_item_seqlens):
+            x_attn_mask[i, :seq_len] = 1
 
-        for layer in self.noise_refiner:
-            x = layer(
-                x,
-                x_attn_mask,
-                x_freqs_cis,
-                adaln_input,
-            )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for layer in self.noise_refiner:
+                x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
+        else:
+            for layer in self.noise_refiner:
+                x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
 
         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
@@ -614,29 +605,18 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
         cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
-        # Reuse padding tensors (convert dtype if needed)
-        cap_pad_tensor = pad_tensor.to(cap_feats[0].dtype) if pad_tensor.dtype != cap_feats[0].dtype else pad_tensor
-        cap_freqs_pad_tensor = (
-            freqs_pad_tensor.to(cap_freqs_cis[0].dtype)
-            if freqs_pad_tensor.dtype != cap_freqs_cis[0].dtype
-            else freqs_pad_tensor
-        )
-        cap_attn_mask = torch.ones((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, (item, freqs_item) in enumerate(zip(cap_feats, cap_freqs_cis)):
-            seq_len = cap_item_seqlens[i]
-            pad_len = cap_max_item_seqlen - seq_len
-            cap_feats[i] = torch.cat([item, cap_pad_tensor.repeat(pad_len, 1)])
-            cap_freqs_cis[i] = torch.cat([freqs_item, cap_freqs_pad_tensor.repeat(pad_len, 1)])
-            cap_attn_mask[i, seq_len:] = 0
-        cap_feats = torch.stack(cap_feats)
-        cap_freqs_cis = torch.stack(cap_freqs_cis)
+        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
+        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
+        for i, seq_len in enumerate(cap_item_seqlens):
+            cap_attn_mask[i, :seq_len] = 1
 
-        for layer in self.context_refiner:
-            cap_feats = layer(
-                cap_feats,
-                cap_attn_mask,
-                cap_freqs_cis,
-            )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for layer in self.context_refiner:
+                cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
+        else:
+            for layer in self.context_refiner:
+                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
 
         # unified
         unified = []
@@ -650,29 +630,18 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
         assert unified_item_seqlens == [len(_) for _ in unified]
         unified_max_item_seqlen = max(unified_item_seqlens)
-        unified_pad_tensor = pad_tensor.to(unified[0].dtype) if pad_tensor.dtype != unified[0].dtype else pad_tensor
-        unified_freqs_pad_tensor = (
-            freqs_pad_tensor.to(unified_freqs_cis[0].dtype)
-            if freqs_pad_tensor.dtype != unified_freqs_cis[0].dtype
-            else freqs_pad_tensor
-        )
-        unified_attn_mask = torch.ones((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, (item, freqs_item) in enumerate(zip(unified, unified_freqs_cis)):
-            seq_len = unified_item_seqlens[i]
-            pad_len = unified_max_item_seqlen - seq_len
-            unified[i] = torch.cat([item, unified_pad_tensor.repeat(pad_len, 1)])
-            unified_freqs_cis[i] = torch.cat([freqs_item, unified_freqs_pad_tensor.repeat(pad_len, 1)])
-            unified_attn_mask[i, seq_len:] = 0
-        unified = torch.stack(unified)
-        unified_freqs_cis = torch.stack(unified_freqs_cis)
+        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
+        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
+        for i, seq_len in enumerate(unified_item_seqlens):
+            unified_attn_mask[i, :seq_len] = 1
 
-        for layer in self.layers:
-            unified = layer(
-                unified,
-                unified_attn_mask,
-                unified_freqs_cis,
-                adaln_input,
-            )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for layer in self.layers:
+                unified = self._gradient_checkpointing_func(layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input)
+        else:
+            for layer in self.layers:
+                unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
 
         unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
         unified = list(unified.unbind(dim=0))
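On the second change: each new branch only checkpoints when torch.is_grad_enabled() and self.gradient_checkpointing both hold, so inference paths are untouched. _gradient_checkpointing_func is supplied by diffusers' ModelMixin once enable_gradient_checkpointing() is called, and by default it wraps torch.utils.checkpoint.checkpoint. A minimal standalone sketch of the same pattern, with a hypothetical toy module standing in for the refiner stacks:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyRefiner(nn.Module):
    # Hypothetical stand-in for noise_refiner / context_refiner / layers.
    def __init__(self, dim=64, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.gradient_checkpointing = False

    def forward(self, x):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.layers:
                # Activations inside `layer` are dropped in forward and recomputed
                # in backward: memory traded for a second forward pass per layer.
                x = checkpoint(layer, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x)
        return x

model = TinyRefiner()
model.gradient_checkpointing = True
out = model(torch.randn(2, 64, requires_grad=True))
out.sum().backward()  # layers rerun here instead of caching activations

The guard also means a model run purely for sampling under torch.no_grad() never pays the recompute cost, even if the flag was left enabled.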