diff --git a/model/model_minimind.py b/model/model_minimind.py
index ecd99b6..e6b6096 100755
--- a/model/model_minimind.py
+++ b/model/model_minimind.py
@@ -194,13 +194,7 @@ class Attention(nn.Module):
         )
 
         if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
-            attn_mask = (
-                None
-                if attention_mask is None
-                else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
-            )
-
-            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
+            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
             scores = scores + torch.triu(
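
For context, a minimal standalone sketch (tensor shapes and names are illustrative, not taken from the repo) of why the explicit `attn_mask` can be dropped in this branch: recent PyTorch raises an error when both `attn_mask` and `is_causal=True` are passed to `F.scaled_dot_product_attention`, and the branch condition already guarantees the padding mask is either `None` or all ones, so the causal flag alone reproduces the intended masking.

```python
# Sketch under the assumptions above; requires PyTorch >= 2.0.
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 4, 8, 16  # hypothetical sizes
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

# Causal masking is handled internally; no explicit attn_mask is needed
# (and passing one together with is_causal=True is rejected by PyTorch).
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 8, 16])
```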