[refactor] Wan single file implementation (#11918)
* update * update * update * add coauthor Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com> * improve test * handle ip adapter params correctly * fix chroma qkv fusion test * fix fastercache implementation * remove set_attention_backend related code * fix more tests * fight more tests * add back set_attention_backend * update * update * make style * make fix-copies * make ip adapter processor compatible with attention dispatcher * refactor chroma as well * attnetion dispatcher support * remove transpose; fix rope shape * remove rmsnorm assert * minify and deprecate npu/xla processors * remove rmsnorm assert * minify and deprecate npu/xla processors * update * Update src/diffusers/models/transformers/transformer_wan.py --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -21,10 +21,10 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -35,18 +35,51 @@ from ..normalization import FP32LayerNorm
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanAttnProcessor2_0:
|
||||
def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
|
||||
# encoder_hidden_states is only passed for cross-attention
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.fused_projections:
|
||||
if attn.cross_attention_dim_head is None:
|
||||
# In self-attention layers, we can fuse the entire QKV projection into a single linear
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
else:
|
||||
# In cross-attention layers, we can only fuse the KV projections into a single linear
|
||||
query = attn.to_q(hidden_states)
|
||||
key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
return query, key, value
|
||||
|
||||
|
||||
def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
|
||||
if attn.fused_projections:
|
||||
key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
|
||||
else:
|
||||
key_img = attn.add_k_proj(encoder_hidden_states_img)
|
||||
value_img = attn.add_v_proj(encoder_hidden_states_img)
|
||||
return key_img, value_img
|
||||
|
||||
|
||||
class WanAttnProcessor:
|
||||
_attention_backend = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
raise ImportError(
|
||||
"WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
attn: "WanAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states_img = None
|
||||
if attn.add_k_proj is not None:
|
||||
@@ -54,21 +87,15 @@ class WanAttnProcessor2_0:
|
||||
image_context_length = encoder_hidden_states.shape[1] - 512
|
||||
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
|
||||
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if rotary_emb is not None:
|
||||
|
||||
@@ -77,8 +104,7 @@ class WanAttnProcessor2_0:
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
):
|
||||
x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
|
||||
x1, x2 = x[..., 0], x[..., 1]
|
||||
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
|
||||
cos = freqs_cos[..., 0::2]
|
||||
sin = freqs_sin[..., 1::2]
|
||||
out = torch.empty_like(hidden_states)
|
||||
@@ -92,23 +118,34 @@ class WanAttnProcessor2_0:
|
||||
# I2V task
|
||||
hidden_states_img = None
|
||||
if encoder_hidden_states_img is not None:
|
||||
key_img = attn.add_k_proj(encoder_hidden_states_img)
|
||||
key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
|
||||
key_img = attn.norm_added_k(key_img)
|
||||
value_img = attn.add_v_proj(encoder_hidden_states_img)
|
||||
|
||||
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key_img = key_img.unflatten(2, (attn.heads, -1))
|
||||
value_img = value_img.unflatten(2, (attn.heads, -1))
|
||||
|
||||
hidden_states_img = F.scaled_dot_product_attention(
|
||||
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
hidden_states_img = dispatch_attention_fn(
|
||||
query,
|
||||
key_img,
|
||||
value_img,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
if hidden_states_img is not None:
|
||||
@@ -119,6 +156,119 @@ class WanAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = (
|
||||
"The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
|
||||
"Please use WanAttnProcessor instead. "
|
||||
)
|
||||
deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
|
||||
return WanAttnProcessor(*args, **kwargs)
|
||||
|
||||
|
||||
class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = WanAttnProcessor
|
||||
_available_processors = [WanAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
eps: float = 1e-5,
|
||||
dropout: float = 0.0,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
cross_attention_dim_head: Optional[int] = None,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.cross_attention_dim_head = cross_attention_dim_head
|
||||
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
|
||||
|
||||
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
|
||||
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_out = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(self.inner_dim, dim, bias=True),
|
||||
torch.nn.Dropout(dropout),
|
||||
]
|
||||
)
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.add_k_proj = self.add_v_proj = None
|
||||
if added_kv_proj_dim is not None:
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def fuse_projections(self):
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if self.cross_attention_dim_head is None:
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
with torch.device("meta"):
|
||||
self.to_qkv = nn.Linear(in_features, out_features, bias=True)
|
||||
self.to_qkv.load_state_dict(
|
||||
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
|
||||
)
|
||||
else:
|
||||
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
with torch.device("meta"):
|
||||
self.to_kv = nn.Linear(in_features, out_features, bias=True)
|
||||
self.to_kv.load_state_dict(
|
||||
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
|
||||
)
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
|
||||
concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
with torch.device("meta"):
|
||||
self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
|
||||
self.to_added_kv.load_state_dict(
|
||||
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
|
||||
)
|
||||
|
||||
self.fused_projections = True
|
||||
|
||||
@torch.no_grad()
|
||||
def unfuse_projections(self):
|
||||
if not getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if hasattr(self, "to_qkv"):
|
||||
delattr(self, "to_qkv")
|
||||
if hasattr(self, "to_kv"):
|
||||
delattr(self, "to_kv")
|
||||
if hasattr(self, "to_added_kv"):
|
||||
delattr(self, "to_added_kv")
|
||||
|
||||
self.fused_projections = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
|
||||
|
||||
|
||||
class WanImageEmbedding(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
|
||||
super().__init__()
|
||||
@@ -247,8 +397,8 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
|
||||
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
|
||||
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
@@ -269,33 +419,24 @@ class WanTransformerBlock(nn.Module):
|
||||
|
||||
# 1. Self-attention
|
||||
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
self.attn1 = WanAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
cross_attention_dim_head=None,
|
||||
processor=WanAttnProcessor(),
|
||||
)
|
||||
|
||||
# 2. Cross-attention
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
self.attn2 = WanAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
added_proj_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
cross_attention_dim_head=dim // num_heads,
|
||||
processor=WanAttnProcessor(),
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
@@ -332,12 +473,12 @@ class WanTransformerBlock(nn.Module):
|
||||
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
|
||||
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
|
||||
attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
|
||||
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
|
||||
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
@@ -350,7 +491,9 @@ class WanTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
class WanTransformer3DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
|
||||
):
|
||||
r"""
|
||||
A Transformer model for video-like data used in the Wan model.
|
||||
|
||||
|
||||
@@ -22,12 +22,17 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
|
||||
from .transformer_wan import (
|
||||
WanAttention,
|
||||
WanAttnProcessor,
|
||||
WanRotaryPosEmbed,
|
||||
WanTimeTextImageEmbedding,
|
||||
WanTransformerBlock,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -55,33 +60,22 @@ class WanVACETransformerBlock(nn.Module):
|
||||
|
||||
# 2. Self-attention
|
||||
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
self.attn1 = WanAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
processor=WanAttnProcessor(),
|
||||
)
|
||||
|
||||
# 3. Cross-attention
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
self.attn2 = WanAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
added_proj_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
processor=WanAttnProcessor(),
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
@@ -116,12 +110,12 @@ class WanVACETransformerBlock(nn.Module):
|
||||
norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
|
||||
control_hidden_states
|
||||
)
|
||||
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
|
||||
attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
|
||||
control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
|
||||
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
|
||||
control_hidden_states = control_hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
|
||||
Reference in New Issue
Block a user