Fix CrossAttention._sliced_attention (#563)
* Fix CrossAttention._sliced_attention Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
8d36d5adb1
commit
84616b5de5
@ -249,13 +249,15 @@ class CrossAttention(nn.Module):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def forward(self, hidden_states, context=None, mask=None):
|
def forward(self, hidden_states, context=None, mask=None):
|
||||||
batch_size, sequence_length, dim = hidden_states.shape
|
batch_size, sequence_length, _ = hidden_states.shape
|
||||||
|
|
||||||
query = self.to_q(hidden_states)
|
query = self.to_q(hidden_states)
|
||||||
context = context if context is not None else hidden_states
|
context = context if context is not None else hidden_states
|
||||||
key = self.to_k(context)
|
key = self.to_k(context)
|
||||||
value = self.to_v(context)
|
value = self.to_v(context)
|
||||||
|
|
||||||
|
dim = query.shape[-1]
|
||||||
|
|
||||||
query = self.reshape_heads_to_batch_dim(query)
|
query = self.reshape_heads_to_batch_dim(query)
|
||||||
key = self.reshape_heads_to_batch_dim(key)
|
key = self.reshape_heads_to_batch_dim(key)
|
||||||
value = self.reshape_heads_to_batch_dim(value)
|
value = self.reshape_heads_to_batch_dim(value)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user