mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-05-01 19:58:15 +08:00
Merge pull request #536 from yuyu5333/fix/attn_forward
fix: attn_forward when is_causal=True assert attn_mask is None
This commit is contained in:
commit
ce9394670b
@@ -194,13 +194,7 @@ class Attention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
|
if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
|
||||||
attn_mask = (
|
output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
|
||||||
None
|
|
||||||
if attention_mask is None
|
|
||||||
else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
|
|
||||||
)
|
|
||||||
|
|
||||||
output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
|
|
||||||
else:
|
else:
|
||||||
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||||
scores = scores + torch.triu(
|
scores = scores + torch.triu(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user