mirror of https://github.com/jingyaogong/minimind.git
[update] prompt prefill
This commit is contained in:
parent 05d0b216f6
commit 1279a61681
@@ -179,7 +179,7 @@ class Attention(nn.Module):
         xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

         cos, sin = position_embeddings
-        xq, xk = apply_rotary_pos_emb(xq, xk, cos[:seq_len], sin[:seq_len])
+        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

         # kv_cache implementation
         if past_key_value is not None:
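Note on the rotary-embedding change: dropping the `cos[:seq_len]` / `sin[:seq_len]` slicing suggests that `position_embeddings` is now already gathered for the absolute positions of the incoming tokens (cached prefix included) before it reaches `Attention`. A minimal sketch of that assumption; the names `rope_cos`, `rope_sin`, `cache_len` and the rotary base are hypothetical, not taken from this repository:

import torch

# Hypothetical precomputed rotary tables, shape [max_seq_len, head_dim].
max_seq_len, head_dim = 4096, 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))  # base value is an assumption
freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
rope_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)
rope_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)

# Prefill: cache_len == 0; incremental decode: cache_len == number of cached tokens.
cache_len, seq_len = 12, 1
positions = torch.arange(cache_len, cache_len + seq_len)  # absolute positions of the new tokens

# Select cos/sin by absolute position up front, so the attention module can
# apply them as-is instead of slicing cos[:seq_len] / sin[:seq_len] itself.
position_embeddings = (rope_cos[positions], rope_sin[positions])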
@@ -193,14 +193,11 @@ class Attention(nn.Module):
             repeat_kv(xv, self.n_rep).transpose(1, 2)
         )

-        if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
+        if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
             output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
-            scores = scores + torch.triu(
-                torch.full((seq_len, seq_len), float("-inf"), device=scores.device),
-                diagonal=1
-            ).unsqueeze(0).unsqueeze(0)  # scores+mask
+            scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)

             if attention_mask is not None:
                 extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
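Note on the new `past_key_value is None` guard: `is_causal=True` in `F.scaled_dot_product_attention` assumes query and key sequences start at the same position, so with cached keys prepended it would wrongly mask the prefix. A rough sketch, not the repository's code, of how the cached case could still use SDPA with an explicit additive mask; shapes and names are illustrative:

import torch
import torch.nn.functional as F

bsz, n_heads, head_dim = 1, 2, 8
cache_len, seq_len = 5, 3                 # 5 cached tokens, 3 new tokens
total_len = cache_len + seq_len

xq = torch.randn(bsz, n_heads, seq_len, head_dim)
xk = torch.randn(bsz, n_heads, total_len, head_dim)
xv = torch.randn(bsz, n_heads, total_len, head_dim)

# Every query may attend to all cached keys; causality only constrains the
# trailing seq_len x seq_len block of new keys.
attn_mask = torch.zeros(seq_len, total_len)
attn_mask[:, -seq_len:] = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask)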
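Note on the reworked manual mask: with a KV cache, `scores` has shape [bsz, n_heads, seq_len, cache_len + seq_len], so the upper-triangular -inf is added only over the last seq_len key columns while the cached columns stay unmasked. A small standalone check of that indexing (everything except the indexing itself is arbitrary):

import torch

cache_len, seq_len = 2, 3
scores = torch.zeros(1, 1, seq_len, cache_len + seq_len)  # stand-in for (xq @ xk^T) / sqrt(head_dim)

# Same indexing as the new line in the diff: only the new-token block is made causal.
scores[:, :, :, -seq_len:] += torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
)

print(scores[0, 0])
# row i keeps 0 for the cached columns and for new keys up to position i,
# and -inf for new keys after position i:
# [[0., 0., 0., -inf, -inf],
#  [0., 0., 0.,   0., -inf],
#  [0., 0., 0.,   0.,   0.]]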