update readme

This commit is contained in:
gongjy
2024-10-08 23:40:29 +08:00
parent 000b0a496b
commit 772834148e
4 changed files with 20 additions and 30 deletions
+5 -1
View File
@@ -27,11 +27,15 @@ class RMSNorm(torch.nn.Module):
return output * self.weight
def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0):
def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0, train_len: int = 512):
    """Precompute the complex rotary position factors ``exp(i * t * freq)``.

    Args:
        dim: Size of the dimension being rotated. Values are paired up, so
            ``dim // 2`` frequencies are produced.
        end: Number of positions to generate (``t = 0 .. end - 1``).
        theta: Base of the geometric frequency progression.
        train_len: Currently unused. It is kept in the signature for the
            disabled linear length-extrapolation experiment described in the
            note below.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)`` holding unit-magnitude
        values whose angle at ``[t, k]`` is ``t * theta ** (-2k / dim)``.
    """
    # One inverse frequency per pair of channels; the slice trims the extra
    # element that ``arange(0, dim, 2)`` yields when ``dim`` is odd.
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    # Angle for every (position, frequency) combination.
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore

    # Unit-magnitude complex numbers cos(angle) + i * sin(angle).
    pos_cis = torch.polar(torch.ones_like(angles), angles)  # complex64

    # NOTE: linear length extrapolation (multiplying ``pos_cis`` by
    # ``train_len / end``) is intentionally not applied. Small models fit
    # ``pos_cis`` very tightly, so direct linear extrapolation did not work
    # well in practice.
    return pos_cis