From f5374dc87ff469c909331115aa11cd3d6df43f59 Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Wed, 19 Nov 2025 22:26:53 +0800 Subject: [PATCH] [fix] model attn_mask --- model/model_minimind.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/model/model_minimind.py b/model/model_minimind.py index 2d640f6..e6b6096 100755 --- a/model/model_minimind.py +++ b/model/model_minimind.py @@ -193,16 +193,8 @@ class Attention(nn.Module): repeat_kv(xv, self.n_rep).transpose(1, 2) ) - if self.flash and seq_len > 1: - if attention_mask is None or torch.all(attention_mask == 1): - attn_mask, is_causal = None, True - else: - causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=xq.device), diagonal=1) - extended_mask = (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * float("-inf") - attn_mask, is_causal = causal_mask.unsqueeze(0) + extended_mask, False - - dropout_p = self.dropout if self.training else 0.0 - output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)): + output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True) else: scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) scores = scores + torch.triu(