mirror of https://github.com/jingyaogong/minimind.git (synced 2026-01-13 19:57:20 +08:00)
fix: attn_forward: when is_causal=True, assert attn_mask is None
parent 9c98cabc9a
commit 7d02ce673c
@@ -194,13 +194,7 @@ class Attention(nn.Module):
             )
 
         if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
-            attn_mask = (
-                None
-                if attention_mask is None
-                else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
-            )
-
-            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
+            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
             scores = scores + torch.triu(
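For reference, a minimal standalone sketch of the constraint behind this change: per the PyTorch documentation, F.scaled_dot_product_attention throws an error if both attn_mask and is_causal are set, so the flash branch above can only pass is_causal=True once the explicit mask is dropped (the branch is already guarded to run only when the padding mask is absent or all ones). The shapes below are illustrative and not taken from the repository.

# Minimal sketch of the attn_mask / is_causal constraint; shapes are illustrative.
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 8, 16, 64
xq = torch.randn(bsz, n_heads, seq_len, head_dim)
xk = torch.randn(bsz, n_heads, seq_len, head_dim)
xv = torch.randn(bsz, n_heads, seq_len, head_dim)

# Fixed fast path: rely on is_causal alone and pass no explicit mask.
out = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])

# Combining an explicit attn_mask with is_causal=True is rejected by PyTorch
# (the docs state an error is thrown if both are set), which is what the
# removed code could trigger whenever attention_mask was not None.
attn_mask = torch.ones(bsz, 1, 1, seq_len, dtype=torch.bool).expand(bsz, n_heads, seq_len, seq_len)
try:
    F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, is_causal=True)
except RuntimeError as err:
    print("rejected as expected:", err)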