Mirror of https://github.com/jingyaogong/minimind.git (synced 2026-01-13 19:57:20 +08:00)
[fix] update model
This commit is contained in:
Parent: ce9394670b
Commit: a044578d73
@@ -193,8 +193,16 @@ class Attention(nn.Module):
             repeat_kv(xv, self.n_rep).transpose(1, 2)
         )

-        if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
-            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
+        if self.flash and seq_len > 1:
+            if attention_mask is None or torch.all(attention_mask == 1):
+                attn_mask, is_causal = None, True
+            else:
+                causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=xq.device), diagonal=1)
+                extended_mask = (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * float("-inf")
+                attn_mask, is_causal = causal_mask.unsqueeze(0) + extended_mask, False
+
+            dropout_p = self.dropout if self.training else 0.0
+            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
             scores = scores + torch.triu(
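For readers comparing the two versions: the old branch could only take the fast scaled_dot_product_attention path when every position was attended, since it relied solely on is_causal=True; the new branch also covers padded batches by building an explicit additive mask (causal plus padding) and passing it via attn_mask with is_causal=False. The standalone sketch below illustrates that masking pattern; the tensor shapes, the example dimensions, and the use of masked_fill (rather than multiplying a 0/1 mask by -inf) are assumptions made for illustration, not the repository's exact code.

import torch
import torch.nn.functional as F

# Illustrative sketch of the new masking branch (assumed shapes and names).
batch, n_heads, seq_len, head_dim = 2, 4, 6, 8
xq = torch.randn(batch, n_heads, seq_len, head_dim)
xk = torch.randn(batch, n_heads, seq_len, head_dim)
xv = torch.randn(batch, n_heads, seq_len, head_dim)

# Padding mask: 1 = real token, 0 = padding (last two tokens of sample 0 are padded).
attention_mask = torch.ones(batch, seq_len)
attention_mask[0, -2:] = 0

if attention_mask is None or torch.all(attention_mask == 1):
    # No padding anywhere: let the kernel apply causal masking itself.
    attn_mask, is_causal = None, True
else:
    # Causal part: block attention to future key positions.
    causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    # Padding part: block attention to padded key positions.
    pad = torch.zeros(batch, 1, 1, seq_len).masked_fill(
        attention_mask[:, None, None, :] == 0, float("-inf")
    )
    # Broadcasts to (batch, 1, seq_len, seq_len); heads broadcast inside SDPA.
    attn_mask, is_causal = causal[None, None, :, :] + pad, False

output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, is_causal=is_causal)
print(output.shape)  # torch.Size([2, 4, 6, 8])

Keeping attn_mask=None with is_causal=True when no padding is present preserves the kernel's fused causal masking; the explicit (batch, 1, seq_len, seq_len) mask is only materialized when the batch actually contains padded positions.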