From 7d02ce673cdbfe54a47c1a241a74c4b9620ca3e0 Mon Sep 17 00:00:00 2001
From: yuyu5333 <1812107659@qq.com>
Date: Tue, 18 Nov 2025 03:17:17 +0000
Subject: [PATCH] fix: attn_forward asserts attn_mask is None when is_causal=True

---
 model/model_minimind.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/model/model_minimind.py b/model/model_minimind.py
index ecd99b6..e6b6096 100755
--- a/model/model_minimind.py
+++ b/model/model_minimind.py
@@ -194,13 +194,7 @@ class Attention(nn.Module):
         )
 
         if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
-            attn_mask = (
-                None
-                if attention_mask is None
-                else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
-            )
-
-            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
+            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
             scores = scores + torch.triu(
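
Background for the change (not part of the patch): PyTorch's F.scaled_dot_product_attention rejects an explicit attn_mask when is_causal=True on the code paths that enforce this check, raising "Explicit attn_mask should not be set when is_causal=True". Since the fast path above only runs when attention_mask is None or all ones, the causal flag alone is sufficient and the mask can be dropped. The sketch below is an illustrative reproduction with assumed tensor shapes, not code taken from model_minimind.py; the try/except also covers PyTorch versions or backends that accept the combination without raising.

# Illustrative repro (assumed shapes, not from model_minimind.py).
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 4, 16, 32
xq = torch.randn(bsz, n_heads, seq_len, head_dim)
xk = torch.randn(bsz, n_heads, seq_len, head_dim)
xv = torch.randn(bsz, n_heads, seq_len, head_dim)

# An all-ones padding mask, expanded the way the removed code did.
attn_mask = torch.ones(bsz, 1, 1, seq_len, dtype=torch.bool).expand(
    bsz, n_heads, seq_len, seq_len
)

# Pre-fix call: combining an explicit mask with is_causal=True can raise a
# RuntimeError on recent PyTorch.
try:
    F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, is_causal=True)
except RuntimeError as err:
    print("old call rejected:", err)

# Post-fix call: the mask was all ones anyway, so is_causal=True alone
# produces the same causal attention and attn_mask is simply omitted.
out = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=0.0, is_causal=True)
print("fixed call output shape:", tuple(out.shape))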