mirror of
https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-04-25 08:58:17 +08:00
bug fix
This commit is contained in:
parent
65cc17a68c
commit
5e6ec490bd
@ -1288,7 +1288,7 @@
|
||||
" \n",
|
||||
" # 使用 softmax 函数和缩放因子归一化注意力分数\n",
|
||||
" # 注意这里的 dim=1,表示沿着键向量的维度进行归一化\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
"\n",
|
||||
" # 使用归一化的注意力权重和值向量计算上下文向量\n",
|
||||
" context_vec = attn_weights @ values\n",
|
||||
@ -1380,7 +1380,7 @@
|
||||
"keys = sa_v2.W_key(inputs) \n",
|
||||
"attn_scores = queries @ keys.T\n",
|
||||
"# 此处的注意力权重和上一节中的一致\n",
|
||||
"attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
|
||||
"attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
"print(attn_weights)"
|
||||
]
|
||||
},
|
||||
@ -1759,7 +1759,7 @@
|
||||
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
|
||||
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
|
||||
" # 经过 softmax \n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
" # 进行 dropout\n",
|
||||
" attn_weights = self.dropout(attn_weights) # New\n",
|
||||
" # 得到最后结果\n",
|
||||
|
||||
@ -92,7 +92,7 @@
|
||||
" values = self.W_value(x)\n",
|
||||
" \n",
|
||||
" attn_scores = queries @ keys.T\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
"\n",
|
||||
" context_vec = attn_weights @ values\n",
|
||||
" return context_vec\n",
|
||||
|
||||
@ -230,7 +230,7 @@
|
||||
" # 使用掩码将未来位置的注意力分数置为负无穷,实现因果自注意力\n",
|
||||
" attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens], -torch.inf)\n",
|
||||
" # 归一化注意力分数\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
" # 应用dropout\n",
|
||||
" attn_weights = self.dropout(attn_weights)\n",
|
||||
"\n",
|
||||
|
||||
@ -149,7 +149,7 @@
|
||||
|
||||
| 姓名 | 职责 | 简介 | GitHub |
|
||||
| :-----:| :----------:| :-----------:|:------:|
|
||||
| 陈可为 | 项目负责人 | 华中科技大学 |[@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus)|
|
||||
| 陈可为 | 项目负责人 | 中国科学院大学 |[@Ethan-Chen-plus](https://github.com/Ethan-Chen-plus)|
|
||||
| 王训志 | 第2章贡献者 | 南开大学 |[@aJupyter](https://github.com/aJupyter)|
|
||||
| 汪健麟 | 第2章贡献者 | ||
|
||||
| Aria | 第2章贡献者 | |[@ariafyy](https://github.com/ariafyy)|
|
||||
|
||||
@ -46,7 +46,7 @@
|
||||
"id": "490fa60b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -92,7 +92,7 @@
|
||||
"id": "92e1e8d6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -32,7 +32,7 @@
|
||||
"id": "f2df060a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -377,7 +377,7 @@
|
||||
"id": "f85ecff5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -43,7 +43,7 @@
|
||||
"id": "e843aae2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -220,7 +220,7 @@
|
||||
"id": "187ca144",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -305,7 +305,7 @@
|
||||
"id": "cdae01fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -45,7 +45,7 @@
|
||||
"id": "490fa60b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -70,7 +70,7 @@
|
||||
"id": "acc76cd1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -37,7 +37,7 @@
|
||||
"id": "e843aae2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -308,7 +308,7 @@
|
||||
"id": "70af4d55",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -535,7 +535,7 @@
|
||||
"id": "3b942805",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -23,7 +23,7 @@
|
||||
"id": "e85089aa-8671-4e5f-a2b3-ef252004ee4c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<img src=\"https://github.com/datawhalechina/llms-from-scratch-cn/blob/main/Translated_Book/img/fig-2-15.jpg?raw=true\" width=\"400px\">"
|
||||
"<img src=\"../img/fig-2-15.jpg\" width=\"400px\">"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -178,7 +178,7 @@
|
||||
"id": "f33c2741-bf1b-4c60-b7fd-61409d556646",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<img src=\"https://github.com/datawhalechina/llms-from-scratch-cn/blob/main/Translated_Book/img/fig-2-16.jpg?raw=true\" width=\"500px\">"
|
||||
"<img src=\"../img/fig-2-16.jpg\" width=\"500px\">"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -61,7 +61,7 @@
|
||||
"id": "972f6e5a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user