Mirror of https://github.com/jingyaogong/minimind.git (synced 2026-04-23 15:58:15 +08:00)
[update] open causal
parent 9348fde743
commit 2ab6455d9d
@@ -95,6 +95,7 @@ class Attention(nn.Module):
         self.n_local_kv_heads = self.num_key_value_heads
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = config.head_dim
+        self.is_causal = True
         self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
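For context, n_rep is the grouped-query-attention repetition factor: how many query heads share each key/value head. The repeat_kv helper referenced in the next hunk is unchanged by this commit and not shown in the diff; the following is only a minimal sketch of the conventional implementation (an assumption, matching the (batch, seq_len, num_kv_heads, head_dim) layout implied by the transposes below):

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Repeat each KV head n_rep times along a new axis so the key/value
    # tensors line up with the number of query heads.
    bs, slen, num_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, num_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, num_kv_heads * n_rep, head_dim)
    )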
@@ -120,11 +121,11 @@ class Attention(nn.Module):
             xv = torch.cat([past_key_value[1], xv], dim=1)
         past_kv = (xk, xv) if use_cache else None
         xq, xk, xv = (xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2))
-        if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
-            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
+        if self.flash and (seq_len > 1) and (not self.is_causal or past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
+            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=self.is_causal)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
-            scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
+            if self.is_causal: scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
             if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
             output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv
         output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
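For illustration only (not part of the commit), a small standalone sketch of what the new is_causal switch controls, assuming no dropout, no KV cache, and no padding mask: with is_causal=True the SDPA fast path and the manual else-branch both mask out future positions and agree numerically; with is_causal=False attention becomes bidirectional.

import math
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 1, 2, 4, 8
xq = torch.randn(bsz, n_heads, seq_len, head_dim)
xk = torch.randn(bsz, n_heads, seq_len, head_dim)
xv = torch.randn(bsz, n_heads, seq_len, head_dim)
is_causal = True  # the new attribute; set False for bidirectional attention

# Flash / SDPA path (dropout_p defaults to 0.0 here)
out_sdpa = F.scaled_dot_product_attention(xq, xk, xv, is_causal=is_causal)

# Manual path, mirroring the else-branch in the diff
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(head_dim)
if is_causal:
    # Upper-triangular -inf mask above the diagonal blocks attention to future tokens
    scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf")).triu(1)
out_manual = F.softmax(scores.float(), dim=-1).type_as(xq) @ xv

print(torch.allclose(out_sdpa, out_manual, atol=1e-5))  # expect True (no dropout)

The relaxed flash-path condition (not self.is_causal or past_key_value is None) presumably lets SDPA be used together with a KV cache only in the non-causal case, since SDPA's is_causal flag assumes equal query/key lengths and would mis-align the mask during cached decoding.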