fix: attn_forwad when is_causal=True assert attn_mask is None

This commit is contained in:
yuyu5333 2025-11-18 03:17:17 +00:00
parent 9c98cabc9a
commit 7d02ce673c

View File

@ -194,13 +194,7 @@ class Attention(nn.Module):
)
if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
attn_mask = (
None
if attention_mask is None
else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
)
output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores = scores + torch.triu(