mirror of
https://github.com/datawhalechina/llms-from-scratch-cn.git
synced 2026-06-06 00:04:42 +00:00
0527
@@ -1,23 +1,40 @@
|
||||
# llama3 implemented from scratch
in this file, i implemented llama3 from scratch, one tensor and matrix multiplication at a time.
<br>
the notebook can be run locally: llama3-from-scratch.ipynb
<br>
also, i'm going to load tensors directly from the model file that meta provided for llama3, so you need to download the weights before running this file.
here is the official link to download the weights: https://llama.meta.com/llama-downloads/
|
||||
|
||||
<div>
|
||||
<img src="images/archi.png"/>
|
||||
</div>
|
||||
|
||||
## tokenizer
i'm not going to implement a bpe tokenizer (but andrej karpathy has a really clean implementation)
<br>
link to his implementation: https://github.com/karpathy/minbpe
<br>
https://hf-mirror.com/NousResearch/Meta-Llama-3-8B
https://gitee.com/hf-models/Meta-Llama-3-8B-Instruct/
|
||||
|
||||
<div>
|
||||
<img src="images/karpathyminbpe.png" width="600"/>
|
||||
</div>
|
||||
|
||||
|
||||
```python
%env HF_ENDPOINT=https://hf-mirror.com
```

env: HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
|
||||
|
||||
```python
|
||||
%pip install blobfile -q
|
||||
```
|
||||
|
||||
Note: you may need to restart the kernel to use updated packages.
|
||||
|
||||
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
@@ -27,7 +44,7 @@ import torch
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
tokenizer_path = "Meta-Llama-3-8B/tokenizer.model"
|
||||
tokenizer_path = "./tokenizer.model"
|
||||
special_tokens = [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
@@ -58,17 +75,37 @@ tokenizer.decode(tokenizer.encode("hello world!"))
|
||||
|
||||
|
||||
|
||||
## reading the model file
normally, reading this depends on how the model classes are written and the variable names inside them.
<br>
but since we are implementing llama3 from scratch we will read the file one tensor at a time.
|
||||
|
||||
<div>
|
||||
<img src="images/model.png" width="600"/>
|
||||
</div>
|
||||
|
||||
the model can be downloaded here: https://gitee.com/hf-models/Meta-Llama-3-8B-Instruct/blob/main/original/consolidated.00.pth
|
||||
|
||||
|
||||
```python
|
||||
model = torch.load("Meta-Llama-3-8B/consolidated.00.pth")
|
||||
!wget 'https://lfs.gitee.com/api/lfs/storage/projects/34266234/be52262c9289304f3e8240e0749bf257bc04264405a86cd4de38efb9068724ee?Expires=1716626632&Signature=xgDOu9JHNM6ECazR3nA4NQHwXs%2BiG%2BCtnzza6ekSuqs%3D&FileName=consolidated.00.pth'
|
||||
```
|
||||
|
||||
--2024-05-25 16:24:15-- https://lfs.gitee.com/api/lfs/storage/projects/34266234/be52262c9289304f3e8240e0749bf257bc04264405a86cd4de38efb9068724ee?Expires=1716626632&Signature=xgDOu9JHNM6ECazR3nA4NQHwXs%2BiG%2BCtnzza6ekSuqs%3D&FileName=consolidated.00.pth
|
||||
Resolving lfs.gitee.com (lfs.gitee.com)... 180.76.198.180
|
||||
Connecting to lfs.gitee.com (lfs.gitee.com)|180.76.198.180|:443... connected.
|
||||
HTTP request sent, awaiting response... 200 OK
|
||||
Length: 16060617592 (15G) [application/octet-stream]
|
||||
Saving to: ‘be52262c9289304f3e8240e0749bf257bc04264405a86cd4de38efb9068724ee?Expires=1716626632&Signature=xgDOu9JHNM6ECazR3nA4NQHwXs+iG+Ctnzza6ekSuqs=&FileName=consolidated.00.pth’
|
||||
|
||||
0% [ ] 105,193,134 453KB/s eta 11h 21m^C
|
||||
|
||||
|
||||
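on my machine the model loads in about 12s. from here on i run inference on cpu only; 30G of memory is enough on my side, and predicting one token on cpu takes about 30s, which is a bit slow, but the main goal here is to understand the principles.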
|
||||
|
||||
|
||||
```python
|
||||
model = torch.load("/data1/ckw/consolidated.00.pth")
|
||||
print(json.dumps(list(model.keys())[:20], indent=4))
|
||||
```
|
||||
|
||||
@@ -98,7 +135,7 @@ print(json.dumps(list(model.keys())[:20], indent=4))
|
||||
|
||||
|
||||
```python
with open("./params.json", "r") as f:
    config = json.load(f)
config
```
|
||||
@@ -118,10 +155,10 @@ config
|
||||
|
||||
|
||||
|
||||
## we use this config to infer details about the model like
1. the model has 32 transformer layers
2. each multi-head attention block has 32 heads
3. the vocab size and so on
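
to make the "and so on" concrete, here is a minimal sketch of deriving a couple of shapes from the config. the json text below is an assumption standing in for params.json (values as i understand the 8B release, not read from the actual file), and `head_dim` / `heads_per_kv` are names made up for this example.

```python
import json

# assumed contents of params.json for the 8B model (illustrative, not read from disk)
params_text = """
{"dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8,
 "vocab_size": 128256, "multiple_of": 1024, "ffn_dim_multiplier": 1.3,
 "norm_eps": 1e-05, "rope_theta": 500000.0}
"""
config = json.loads(params_text)

head_dim = config["dim"] // config["n_heads"]             # width of one attention head
heads_per_kv = config["n_heads"] // config["n_kv_heads"]  # query heads sharing one key/value head
print(head_dim, heads_per_kv)
```

128 is the per-head query size and 4 is how many query heads share one key/value head, both of which come up again in the attention sections.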
|
||||
|
||||
|
||||
```python
|
||||
@@ -136,8 +173,9 @@ norm_eps = config["norm_eps"]
|
||||
rope_theta = torch.tensor(config["rope_theta"])
|
||||
```
|
||||
|
||||
## converting text to tokens
here we use tiktoken (an openai library) as the tokenizer
|
||||
|
||||
<div>
|
||||
<img src="images/tokens.png" width="600"/>
|
||||
</div>
|
||||
@@ -156,13 +194,13 @@ print(prompt_split_as_tokens)
|
||||
['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']
|
||||
|
||||
|
||||
## converting tokens to their embedding
IM SORRY but this is the only part of the codebase where i use an inbuilt neural network module
<br>
anyway, so our [17x1] tokens are now [17x4096], i.e. 17 embeddings (one for each token) of length 4096
<br>
<br>
note: keep track of the shapes, it makes it much easier to understand everything
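
a minimal sketch of what the embedding step does, assuming toy sizes and a random table in place of the real `tok_embeddings.weight` (`embedding_table` is a name made up for this example). an embedding layer is just a row lookup, so the same token id always gets the same vector.

```python
import torch

torch.manual_seed(0)
vocab_size, dim = 1000, 64                       # toy sizes; the text uses a 4096-wide embedding
embedding_table = torch.randn(vocab_size, dim)   # stand-in for the real embedding weights
tokens = torch.tensor([5, 42, 7, 42])            # made-up token ids

# an embedding layer is a row lookup into the table
token_embeddings = embedding_table[tokens]
print(token_embeddings.shape)
```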
|
||||
|
||||
<div>
|
||||
<img src="images/embeddings.png" width="600"/>
|
||||
@@ -183,12 +221,12 @@ token_embeddings_unnormalized.shape
|
||||
|
||||
|
||||
|
||||
## we then normalize the embedding using rms normalization
please note that after this step the shapes don't change, the values are just normalized
<br>
one thing to keep in mind: we need a norm_eps (from config) because we don't want to accidentally end up with an rms of 0 and divide by 0
<br>
here is the formula:
|
||||
<div>
|
||||
<img src="images/rms.png" width="600"/>
|
||||
</div>
|
||||
@@ -202,12 +240,12 @@ def rms_norm(tensor, norm_weights):
|
||||
return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
|
||||
```
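
a small self-contained check of the same idea on a toy tensor. `rms_norm_sketch` and its `eps` default are stand-ins for this example (the notebook's own function reads norm_eps from the config). the second row is all zeros, which is exactly the case eps protects against.

```python
import torch

def rms_norm_sketch(tensor, norm_weights, eps=1e-5):
    # mean of squares over the last (feature) dimension, one value per token
    mean_sq = tensor.pow(2).mean(-1, keepdim=True)
    # rsqrt(mean_sq + eps) is 1 / rms; eps keeps it finite for all-zero rows
    return tensor * torch.rsqrt(mean_sq + eps) * norm_weights

x = torch.tensor([[3.0, 4.0], [0.0, 0.0]])
y = rms_norm_sketch(x, torch.ones(2))
print(y)
```

after normalization the first row has an rms of roughly 1, and the zero row stays finite instead of turning into nan.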
|
||||
|
||||
# building the first layer of the transformer

### normalization
you will see me accessing layers.0 from the model dict (this is the first layer)
<br>
anyway, so after normalizing our shapes are still [17x4096], same as the embedding but normalized
|
||||
|
||||
<div>
|
||||
<img src="images/norm.png" width="600"/>
|
||||
@@ -226,21 +264,22 @@ token_embeddings.shape
|
||||
|
||||
|
||||
|
||||
### attention implemented from scratch
let's load the attention heads of the first layer of the transformer
<div>
<img src="images/qkv.png" width="600"/>
</div>

<br>

> when we load the query, key, value and output weights from the model we notice the shapes to be [4096x4096], [1024x4096], [1024x4096], [4096x4096]
<br>
> at first glance this is weird because ideally we want each q, k, v and o for each head individually
<br>
> the authors of the code bundled them together because it's easy and it helps parallelize attention head multiplication.
<br>
> i'm going to unwrap everything...
|
||||
|
||||
|
||||
```python
|
||||
@@ -255,10 +294,10 @@ print(
|
||||
torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096]) torch.Size([4096, 4096])
|
||||
|
||||
|
||||
### unwrapping query
in the next section we will unwrap the queries from multiple attention heads, the resulting shape is [32x128x4096]
<br><br>
here, 32 is the number of attention heads in llama3, 128 is the size of the query vector and 4096 is the size of the token embedding
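
a minimal sketch of the unwrapping, assuming toy sizes and random weights (`wq` and `wq_per_head` are stand-in names, not tensors from the model file). the point is that `.view` only regroups rows: head h owns a contiguous block of rows in the bundled matrix.

```python
import torch

torch.manual_seed(0)
n_heads, head_dim, dim = 4, 8, 32            # toy sizes; the text uses 32, 128 and 4096
wq = torch.randn(n_heads * head_dim, dim)    # stand-in for the bundled query weights
wq_per_head = wq.view(n_heads, head_dim, dim)

# head h owns rows h*head_dim ... (h+1)*head_dim - 1 of the bundled matrix
assert torch.equal(wq_per_head[1], wq[head_dim:2 * head_dim])

x = torch.randn(17, dim)                     # 17 token embeddings
q_head0 = x @ wq_per_head[0].T               # one query vector per token for head 0
print(wq_per_head.shape, q_head0.shape)
```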
|
||||
|
||||
|
||||
```python
|
||||
@@ -275,8 +314,8 @@ q_layer0.shape
|
||||
|
||||
|
||||
|
||||
### i'm going to implement the first head of the first layer
here i access the query weight matrix of the first head of the first layer, the size of this query weight matrix is [128x4096]
|
||||
|
||||
|
||||
```python
|
||||
@@ -291,8 +330,9 @@ q_layer0_head0.shape
|
||||
|
||||
|
||||
|
||||
### we now multiply the query weights with the token embedding, to receive a query for the token
here you can see the resulting shape is [17x128], this is because we have 17 tokens and for each token there is a 128 length query.
|
||||
|
||||
<div>
|
||||
<img src="images/q_per_token.png" width="600"/>
|
||||
</div>
|
||||
@@ -310,17 +350,16 @@ q_per_token.shape
|
||||
|
||||
|
||||
|
||||
## positional encoding
we are now at a stage where we have a query vector for each token in our prompt, but if you think about it -- an individual query vector has no idea about its position in the prompt.
<br><br>
query: "the answer to the ultimate question of life, the universe, and everything is "
<br><br>
in our prompt we have used "the" three times, we need the query vectors of all 3 "the" tokens to be different (each of size [1x128]) based on their positions in the query. we perform these rotations using RoPE (rotary positional embedding).
<br><br>
### RoPE
watch this video (this is what i watched) to understand the math.
https://www.youtube.com/watch?v=o29P0Kpobz0&t=530s
|
||||
|
||||
<div>
|
||||
<img src="images/rope.png" width="600"/>
|
||||
@@ -339,16 +378,15 @@ q_per_token_split_into_pairs.shape
|
||||
|
||||
|
||||
|
||||
in the above step, we split the query vectors into pairs, we apply a rotational angle shift to each pair!
<br><br>
we now have a tensor of size [17x64x2], this is the 128 length queries split into 64 pairs for each token in the prompt! pair i will be rotated by m*theta_i, where m is the position of the token for which we are rotating the query and theta_i is the frequency assigned to that pair!
|
||||
|
||||
<div>
|
||||
<img src="images/qsplit.png" width="600"/>
|
||||
</div>
|
||||
|
||||
## using multiplication of complex numbers to rotate a vector
|
||||
<div>
|
||||
<img src="images/freq_cis.png" width="600"/>
|
||||
</div>
|
||||
@@ -398,32 +436,39 @@ freqs
|
||||
|
||||
|
||||
```python
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
```
|
||||
|
||||
|
||||
```python
freqs_for_each_token = torch.outer(torch.arange(17), freqs)
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
freqs_cis.shape

# viewing the row of freqs_cis at index 3 (the rotations for the fourth token)
value = freqs_cis[3]
plt.figure()
for i, element in enumerate(value[:17]):
    plt.plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f"Index: {i}")
    plt.annotate(f"{i}", xy=(element.real, element.imag), color='red')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.title('Plot of one row of freqs_cis')
plt.show()
```
|
||||
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
|
||||
### now that we have a complex number (the angle change vector) for every token's query element
we can convert our queries (the ones we split into pairs) into complex numbers and then multiply them elementwise by these rotations to rotate each query based on its position
<br>
honestly this is beautiful to think about :)
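
a minimal end to end sketch of the rotation, assuming toy sizes and a made-up `rope_theta` (the real value comes from the config). the names mirror the notebook loosely but everything here is built from random data. two properties are worth checking: the token at position 0 is not rotated at all, and rotation never changes the length of a query.

```python
import torch

torch.manual_seed(0)
head_dim, n_tokens, rope_theta = 8, 5, 10000.0   # toy values, not the real config
q = torch.randn(n_tokens, head_dim)

# one rotation frequency theta_i per pair of dimensions
freqs = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(n_tokens).float(), freqs)   # angle m * theta_i
freqs_cis = torch.polar(torch.ones_like(angles), angles)      # unit complex numbers

q_pairs = q.float().view(n_tokens, -1, 2)          # [tokens, head_dim/2, 2]
q_complex = torch.view_as_complex(q_pairs)         # each pair becomes a + bi
q_rotated = torch.view_as_real(q_complex * freqs_cis).reshape(n_tokens, head_dim)
print(q_rotated.shape)
```

multiplying by a unit complex number is a pure rotation, which is why the norms survive.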
|
||||
|
||||
|
||||
```python
|
||||
@@ -451,8 +496,8 @@ q_per_token_as_complex_numbers_rotated.shape
|
||||
|
||||
|
||||
|
||||
### after rotated vector is obtained
we can get back the queries as pairs by viewing the complex numbers as real numbers again
|
||||
|
||||
|
||||
```python
|
||||
@@ -467,7 +512,7 @@ q_per_token_split_into_pairs_rotated.shape
|
||||
|
||||
|
||||
|
||||
the rotated pairs are now merged, we now have a new query vector (rotated query vector) that is of the shape [17x128] where 17 is the number of tokens and 128 is the dim of the query vector
|
||||
|
||||
|
||||
```python
|
||||
@@ -482,17 +527,17 @@ q_per_token_rotated.shape
|
||||
|
||||
|
||||
|
||||
# keys (almost the same as queries)
<div>
<img src="images/keys.png" width="600px"/>
</div>
im lazy as fuck, so im not going to go through the math for keys, the only things you need to keep in mind are:
<br>
> keys generate key vectors also of dimension 128
<br>
> keys have only 1/4th the number of weights as queries, this is because the weights for keys are shared across 4 heads at a time, to reduce the number of computations needed
<br>
> keys are also rotated to add positional info, just like queries, for the same reasons
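
the sharing rule can be sketched as a plain index mapping, assuming the head counts mentioned in the text (32 query heads, 8 key/value heads). `group_size` and `kv_index` are names made up for this example.

```python
n_heads, n_kv_heads = 32, 8
group_size = n_heads // n_kv_heads     # 4 query heads per key/value head

# query head h reads the key (and value) weights of head h // group_size
kv_index = [head // group_size for head in range(n_heads)]
print(kv_index[:9])   # [0, 0, 0, 0, 1, 1, 1, 1, 2]
```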
|
||||
|
||||
|
||||
```python
|
||||
@@ -586,19 +631,19 @@ k_per_token_rotated.shape
|
||||
|
||||
|
||||
|
||||
## at this stage we now have both the rotated values of queries and keys, for each token.
<div>
<img src="images/keys0.png" width="600px"/>
</div>
each of the queries and keys are now of shape [17x128].

## in the next step we will multiply the queries and key matrices
doing this will give us a score mapping each token with one another
<br>
this score describes how well each token's query relates to each token's key.
THIS IS SELF ATTENTION :)
<br>
the shape of the attention score matrix (qk_per_token) is [17x17] where 17 is the number of tokens in the prompt
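
a minimal sketch of the score computation, assuming random stand-ins for the rotated queries and keys (`q_rot`, `k_rot` and `scores` are made-up names). dividing by sqrt(head_dim) is the usual scaled dot product attention scaling.

```python
import torch

torch.manual_seed(0)
n_tokens, head_dim = 17, 128
q_rot = torch.randn(n_tokens, head_dim)   # stand-in for the rotated queries
k_rot = torch.randn(n_tokens, head_dim)   # stand-in for the rotated keys

# entry [i, j] is the dot product of token i's query with token j's key,
# divided by sqrt(head_dim) to keep the scores in a softmax-friendly range
scores = (q_rot @ k_rot.T) / head_dim ** 0.5
print(scores.shape)
```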
|
||||
|
||||
<div>
|
||||
<img src="images/qkmatmul.png" width="600px"/>
|
||||
@@ -617,12 +662,13 @@ qk_per_token.shape
|
||||
|
||||
|
||||
|
||||
# we now have to mask query key scores
during the training process of llama3, the future token qk scores are masked.
<br>
why? because during training we only learn to predict tokens using past tokens.
<br>
as a result, during inference we also mask the future tokens: their scores are set to -inf, so their attention weights become zero after softmax.
|
||||
|
||||
<div>
|
||||
<img src="images/mask.png" width="600px"/>
|
||||
</div>
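
a minimal sketch of the masking on a 4 token toy example, assuming all raw scores are equal so the effect of the mask is easy to read off. `scores` and `weights` are names made up for this example.

```python
import torch

n_tokens = 4
scores = torch.zeros(n_tokens, n_tokens)   # dummy scores, all equal

mask = torch.full((n_tokens, n_tokens), float("-inf"))
mask = torch.triu(mask, diagonal=1)        # -inf strictly above the diagonal, 0 elsewhere
weights = torch.softmax(scores + mask, dim=1)
print(weights)
```

each row still sums to 1, but token i only spreads its attention over tokens 0..i, so the first row puts everything on itself and the last row is uniform.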
|
||||
@@ -630,20 +676,21 @@ as a result, during inference we set the future tokens to zero.
|
||||
|
||||
```python
def display_qk_heatmap(qk_per_token):
    fig, ax = plt.subplots(figsize=(30, 8))  # figure size: 30x8 inches
    im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')
    ax.set_xticks(range(len(prompt_split_as_tokens)))
    ax.set_yticks(range(len(prompt_split_as_tokens)))
    ax.set_xticklabels(prompt_split_as_tokens)
    ax.set_yticklabels(prompt_split_as_tokens)
    ax.figure.colorbar(im, ax=ax)

display_qk_heatmap(qk_per_token)
```
|
||||
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
|
||||
@@ -685,7 +732,7 @@ display_qk_heatmap(qk_per_token_after_masking)
|
||||
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
|
||||
@@ -701,21 +748,20 @@ display_qk_heatmap(qk_per_token_after_masking_after_softmax)
|
||||
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
|
||||
## values (almost the end of attention)

<div>
<img src="images/value.png" width="600px"/>
</div>
these scores (0-1) are used to determine how much of the value matrix is used per token
<br>
> just like keys, value weights are also shared across every 4 attention heads (to save computation)
<br>
> as a result, the shape of the value weight matrix below is [8x128x4096]
|
||||
|
||||
|
||||
```python
|
||||
@@ -731,7 +777,7 @@ v_layer0.shape
|
||||
|
||||
|
||||
|
||||
the first layer, first head value weight matrix is given below
|
||||
|
||||
|
||||
```python
|
||||
@@ -746,11 +792,11 @@ v_layer0_head0.shape
|
||||
|
||||
|
||||
|
||||
## value vectors
<div>
<img src="images/v0.png" width="600px"/>
</div>
we now use the value weights to get the attention values per token, this is of size [17x128] where 17 is the number of tokens in the prompt and 128 is the dim of the value vector per token
|
||||
|
||||
|
||||
```python
|
||||
@@ -765,11 +811,11 @@ v_per_token.shape
|
||||
|
||||
|
||||
|
||||
## attention
<div>
<img src="images/attention.png" width="600px"/>
</div>
the resultant attention vector after multiplying with the values per token is of shape [17x128]
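
a minimal sketch of this last step for one head, assuming random stand-ins for the softmaxed scores and the value vectors (`attn_weights`, `v` and `attn_out` are made-up names). each output row is a weighted average of the value vectors.

```python
import torch

torch.manual_seed(0)
n_tokens, head_dim = 17, 128
attn_weights = torch.softmax(torch.randn(n_tokens, n_tokens), dim=1)  # each row sums to 1
v = torch.randn(n_tokens, head_dim)                                   # one value vector per token

# row i of the output is sum_j attn_weights[i, j] * v[j]
attn_out = attn_weights @ v
print(attn_out.shape)
```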
|
||||
|
||||
|
||||
```python
|
||||
@@ -784,13 +830,13 @@ qkv_attention.shape
|
||||
|
||||
|
||||
|
||||
# multi head attention
<div>
<img src="images/heads.png" width="600px"/>
</div>
WE NOW HAVE THE ATTENTION VALUE OF THE FIRST LAYER AND FIRST HEAD
<br>
now i'm going to run a loop and perform the exact same math as the cells above but for every head in the first layer
|
||||
|
||||
|
||||
```python
|
||||
@@ -836,9 +882,9 @@ len(qkv_attention_store)
|
||||
<div>
|
||||
<img src="images/stacked.png" width="600px"/>
|
||||
</div>
|
||||
we now have the qkv_attention matrix for all 32 heads on the first layer, next i'm going to merge all the per-head attention outputs into one large matrix of size [17x4096]
<br>
we are almost at the end :)
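
a minimal sketch of the merge, assuming each head's output is filled with its own head index so it is easy to see where it lands (`per_head` and `stacked` are made-up names). concatenating along the last dimension puts head h in columns h*128 ... h*128+127.

```python
import torch

n_heads, n_tokens, head_dim = 32, 17, 128
# dummy per-head outputs: head h is filled with the value h
per_head = [torch.full((n_tokens, head_dim), float(h)) for h in range(n_heads)]

stacked = torch.cat(per_head, dim=-1)   # [17, 32 * 128]
print(stacked.shape)
```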
|
||||
|
||||
|
||||
```python
|
||||
@@ -853,11 +899,11 @@ stacked_qkv_attention.shape
|
||||
|
||||
|
||||
|
||||
# weight matrix, one of the final steps
<div>
<img src="images/weightmatrix.png" width="600px"/>
</div>
one of the last things to do for layer 0 attention is to multiply the attention values by the output weight matrix
|
||||
|
||||
|
||||
```python
|
||||
@@ -872,7 +918,7 @@ w_layer0.shape
|
||||
|
||||
|
||||
|
||||
### this is a simple linear layer, so we just matmul
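
a small check of the "just matmul" claim, assuming toy sizes and random weights (`w_o` is a stand-in name). weights are stored as [out_features, in_features], so multiplying by the transpose matches what a bias-free linear layer computes.

```python
import torch

torch.manual_seed(0)
x = torch.randn(17, 64)      # toy width; the text uses 4096
w_o = torch.randn(64, 64)    # stand-in for the attention output weights

delta = x @ w_o.T                            # plain matmul with the transposed weights
same = torch.nn.functional.linear(x, w_o)    # bias-free linear layer
print(torch.allclose(delta, same, atol=1e-4))
```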
|
||||
|
||||
|
||||
```python
|
||||
@@ -890,7 +936,7 @@ embedding_delta.shape
|
||||
<div>
|
||||
<img src="images/afterattention.png" width="600px"/>
|
||||
</div>
|
||||
we now have the change in the embedding value after attention, which should be added to the original token embeddings
|
||||
|
||||
|
||||
```python
|
||||
@@ -905,7 +951,7 @@ embedding_after_edit.shape
|
||||
|
||||
|
||||
|
||||
## we normalize and then run a feed forward neural network through the embedding delta
|
||||
<div>
|
||||
<img src="images/norm_after.png" width="600px"/>
|
||||
</div>
|
||||
@@ -923,13 +969,13 @@ embedding_after_edit_normalized.shape
|
||||
|
||||
|
||||
|
||||
## loading the ff weights and implementing the feed forward network
<div>
<img src="images/swiglu.png" width="600px"/>
</div>
in llama3, they used a SwiGLU feedforward network, this network architecture is really good at adding non linearity when needed by the model.
<br>
it's pretty standard to use this feed forward network architecture in llms these days
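
a minimal sketch of a SwiGLU feed forward pass, assuming toy sizes and random weights. the w1 / w3 / w2 naming follows the model file; calling them gate, up and down projections is a common convention i'm assuming here, not something the notebook states.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, hidden = 16, 44             # toy sizes
x = torch.randn(3, dim)          # 3 token embeddings
w1 = torch.randn(hidden, dim)    # "gate" projection
w3 = torch.randn(hidden, dim)    # "up" projection
w2 = torch.randn(dim, hidden)    # "down" projection

gate = F.silu(x @ w1.T)              # silu(z) = z * sigmoid(z)
out = (gate * (x @ w3.T)) @ w2.T     # gate the up projection, then project back to dim
print(out.shape)
```

the elementwise product is what makes it "gated": the silu branch decides how much of each hidden unit in the other branch gets through.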
|
||||
|
||||
|
||||
```python
|
||||
@@ -947,12 +993,12 @@ output_after_feedforward.shape
|
||||
|
||||
|
||||
|
||||
# WE FINALLY HAVE NEW EDITED EMBEDDINGS FOR EACH TOKEN AFTER THE FIRST LAYER
just 31 more layers to go before we are done (one for loop away)
<br>
you can imagine this edited embedding as having information about all queries asked on the first layer
<br>
now each layer will encode more and more complex queries on the questions asked, until we have an embedding that knows everything about the next token that we need.
|
||||
|
||||
|
||||
```python
|
||||
@@ -967,14 +1013,14 @@ layer_0_embedding.shape
|
||||
|
||||
|
||||
|
||||
# god, everything all at once
<div>
<img src="images/god.png" width="600px"/>
</div>
yep, this is it. everything we did before, all at once, for every single layer.
<br>

# have fun reading :)
|
||||
|
||||
|
||||
```python
|
||||
@@ -1024,8 +1070,8 @@ for layer in range(n_layers):
|
||||
final_embedding = embedding_after_edit+output_after_feedforward
|
||||
```
|
||||
|
||||
# we now have the final embedding, the best guess the model could make about the next token
the shape of the embedding is the same as regular token embeddings [17x4096] where 17 is the number of tokens and 4096 is the embedding dim
|
||||
<div>
|
||||
<img src="images/last_norm.png" width="600px"/>
|
||||
</div>
|
||||
@@ -1043,11 +1089,11 @@ final_embedding.shape
|
||||
|
||||
|
||||
|
||||
# finally, let's decode the embedding into the token value
<div>
<img src="images/finallayer.png" width="600px"/>
</div>
we will use the output decoder to convert the final embedding into a token
|
||||
|
||||
|
||||
```python
|
||||
@@ -1061,9 +1107,9 @@ model["output.weight"].shape
|
||||
|
||||
|
||||
|
||||
# we use the embedding of the last token to predict the next value
hopefully in our case, 42 :)
note: 42 is the answer to "the answer to the ultimate question of life, the universe, and everything is ", according to the book "hitchhiker's guide to the galaxy", most modern llms would answer with 42 here, which should validate our entire code! wish me luck :)
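
a minimal sketch of the decoding step, assuming toy sizes and random tensors in place of the real final embedding and output weights (`output_weight` is a stand-in name). only the last token's embedding is used, and greedy decoding just picks the highest logit.

```python
import torch

torch.manual_seed(0)
vocab_size, dim, n_tokens = 50, 16, 5            # toy sizes
final_embedding = torch.randn(n_tokens, dim)     # stand-in for the normalized final embedding
output_weight = torch.randn(vocab_size, dim)     # stand-in for the output weights

logits = final_embedding[-1] @ output_weight.T   # one score per vocabulary entry
next_token = torch.argmax(logits, dim=-1)        # greedy pick: the highest score wins
print(logits.shape, next_token.item())
```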
|
||||
|
||||
|
||||
```python
|
||||
@@ -1078,8 +1124,8 @@ logits.shape
|
||||
|
||||
|
||||
|
||||
### the model predicted token number 2983 as the next token, is this the token number for 42?
IM HYPING YOU UP, this is the last cell of code, hopefully you had fun :)
|
||||
|
||||
|
||||
```python
|
||||
@@ -1111,25 +1157,295 @@ tokenizer.decode([next_token.item()])
|
||||
|
||||
|
||||
|
||||
# thank you, i love you :)

This is the end. Hopefully you enjoyed reading it!

Thanks to the datawhale folks for their support and appreciation.
We are A10 Research, and we are honored that this work has been helpful to everyone.

If you want to support my work

1. follow me on twitter https://twitter.com/naklecha
2. or, buy me a coffee [https://www.buymeacoffee.com/naklecha](https://www.buymeacoffee.com/naklecha)

Honestly, if you made it this far you already made my day :)

## what motivates me?

My friends and I are on a mission - to make research more accessible!
We created a research lab called A10 - [AAAAAAAAAA.org](http://aaaaaaaaaa.org/)

A10 twitter - https://twitter.com/aaaaaaaaaaorg

our thesis:
|
||||
<div>
|
||||
<img src="images/a10.png" width="600px"/>
|
||||
</div>
|
||||
|
||||
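our main goal right now is to make research more accessible. the field is quite messy, and everyone seems to be sharing low entropy, high level insights (lol, the recent buzzwords have an information entropy of 0). we want to dive deep into topics and share what we find with everyone. besides that, we will ship some cool open source projects and train/finetune models (sharing our progress along the way).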
|
||||
|
||||
# note: predicting the next word for "datawhalechina is a group for "
|
||||
|
||||
|
||||
```python
prompt = "datawhalechina is a group for "
tokens = [128000] + tokenizer.encode(prompt)
print(tokens)
tokens = torch.tensor(tokens)
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
print(prompt_split_as_tokens)
```
|
||||
|
||||
[128000, 695, 1336, 1604, 81236, 374, 264, 1912, 369, 220]
|
||||
['<|begin_of_text|>', 'data', 'wh', 'ale', 'china', ' is', ' a', ' group', ' for', ' ']
|
||||
|
||||
|
||||
|
||||
```python
|
||||
embedding_layer = torch.nn.Embedding(vocab_size, dim)
|
||||
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
|
||||
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
|
||||
token_embeddings_unnormalized.shape
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
torch.Size([10, 4096])
|
||||
|
||||
|
||||
|
||||
|
||||
```python
|
||||
from tqdm import tqdm
|
||||
```
|
||||
|
||||
the sequence length here needs to change from 17 to 10
|
||||
|
||||
|
||||
```python
freqs_for_each_token = torch.outer(torch.arange(10), freqs)
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
freqs_cis.shape

# viewing the row of freqs_cis at index 3 (the rotations for the fourth token)
value = freqs_cis[3]
plt.figure()
for i, element in enumerate(value[:10]):
    plt.plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f"Index: {i}")
    plt.annotate(f"{i}", xy=(element.real, element.imag), color='red')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.title('Plot of one row of freqs_cis')
plt.show()
```
|
||||
|
||||
|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
|
||||
```python
final_embedding = token_embeddings_unnormalized
for layer in tqdm(range(n_layers)):
    qkv_attention_store = []
    # attention: normalize, then split the bundled weights into per-head matrices
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head//4]  # keys and values are shared by 4 query heads
        v_layer_head = v_layer[head//4]
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        # rotate queries and keys with RoPE
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        # scaled scores, causal mask, softmax, weighted sum of values
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)

    # merge the heads, project, and add the result to the layer input
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    # feed forward (SwiGLU) on the normalized result
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit+output_after_feedforward
```

    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:59<00:00, 1.87s/it]

```python
# final RMS norm, then project the last token's embedding onto the vocabulary
final_embedding = rms_norm(final_embedding, model["norm.weight"])
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
# greedy decoding: pick the highest-scoring token id
next_token = torch.argmax(logits, dim=-1)
tokenizer.decode([next_token.item()])
```

    ' data'

# Note: partial code drafts

```python
# load the layer-0 key weights and split them into one matrix per kv head
k_layer0 = model["layers.0.attention.wk.weight"]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)
k_layer0_head0 = k_layer0[0]
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)
# RoPE: view adjacent pairs as complex numbers and rotate them by freqs_cis
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
```
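
To see why the complex multiplication above encodes position, here is a minimal torch-free sketch. `rotate_pairs` and `dot` are hypothetical helper names, the 4-dimensional `q` and `k` are made-up toy vectors, and `theta=500000.0` is assumed to match the model's `rope_theta`. It checks that the score between a rotated query and a rotated key depends only on how far apart the two positions are.

```python
import cmath

def rotate_pairs(vec, position, theta=500000.0):
    """Rotate consecutive (even, odd) pairs of vec by position-dependent angles."""
    out = []
    for i in range(len(vec) // 2):
        # pair i uses frequency theta^(-2i/d); the angle grows linearly with position
        angle = position * theta ** (-2.0 * i / len(vec))
        z = complex(vec[2 * i], vec[2 * i + 1]) * cmath.exp(1j * angle)
        out.extend([z.real, z.imag])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# toy query/key vectors (illustrative values only)
q = [1.0, 0.5, -0.3, 2.0]
k = [0.2, -1.0, 0.7, 0.4]

# same relative offset (3) at two different absolute positions
score_a = dot(rotate_pairs(q, 5), rotate_pairs(k, 2))
score_b = dot(rotate_pairs(q, 10), rotate_pairs(k, 7))
print(abs(score_a - score_b) < 1e-9)
```

The rotation also leaves each vector's length unchanged, so RoPE changes the angle between query and key without rescaling either one.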

```python
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
```

```python
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
qk_per_token_after_masking = qk_per_token + mask
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
```
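
The mask-then-softmax step can be sketched without torch. This is a toy version under stated assumptions: `causal_softmax` is a hypothetical helper and the 3x3 score matrix is made up. It shows that adding `-inf` above the diagonal gives future tokens exactly zero weight while every row still sums to 1.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where token i may only attend to tokens 0..i."""
    n = len(scores)
    out = []
    for i in range(n):
        # -inf above the diagonal, same effect as torch.triu(mask, diagonal=1)
        masked = [scores[i][j] if j <= i else float("-inf") for j in range(n)]
        row_max = max(masked)  # subtract the max for numerical stability
        exps = [math.exp(s - row_max) for s in masked]  # exp(-inf) is 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

# toy attention scores (illustrative values only)
weights = causal_softmax([[0.1, 0.9, 0.3],
                          [0.5, 0.2, 0.8],
                          [0.4, 0.4, 0.4]])
for row in weights:
    print(row)
```

The first row always comes out as `[1.0, 0.0, 0.0]`, because the first token can only attend to itself.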

```python
# load the layer-0 value weights and split them into one matrix per kv head
v_layer0 = model["layers.0.attention.wv.weight"]
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)
v_layer0_head0 = v_layer0[0]
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)
# weight the values by the attention probabilities
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
```

```python
qkv_attention_store = []

for head in range(n_heads):
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head//4] # key weights are shared across 4 heads
    v_layer0_head = v_layer0[head//4] # value weights are shared across 4 heads
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)

    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5
    mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
    mask = torch.triu(mask, diagonal=1)
    qk_per_token_after_masking = qk_per_token + mask
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
    qkv_attention_store.append(qkv_attention)

# concatenate all heads, project with wo, and add the residual
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)

w_layer0 = model["layers.0.attention.wo.weight"]
embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)
embedding_after_edit = token_embeddings_unnormalized + embedding_delta
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])
w1 = model["layers.0.feed_forward.w1.weight"]
w2 = model["layers.0.feed_forward.w2.weight"]
w3 = model["layers.0.feed_forward.w3.weight"]
output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
layer_0_embedding = embedding_after_edit+output_after_feedforward
```
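
The `head//4` indexing above is grouped-query attention: several query heads read from the same key/value head. A minimal sketch of the mapping follows. `kv_head_for` is a hypothetical helper, and the defaults `n_heads=32` and `n_kv_heads=8` are assumed from the group size of 4 used in the loop.

```python
def kv_head_for(query_head, n_heads=32, n_kv_heads=8):
    """Return the index of the kv head shared by this query head."""
    group_size = n_heads // n_kv_heads  # 32 // 8 = 4 query heads per kv head
    return query_head // group_size

mapping = [kv_head_for(h) for h in range(32)]
print(mapping)
```

Writing the divisor as `n_heads // n_kv_heads` instead of a literal `4` would keep the loop correct for configurations with a different group size.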

```python
final_embedding = token_embeddings_unnormalized
for layer in range(n_layers):
    qkv_attention_store = []
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head//4] # key weights are shared across 4 heads
        v_layer_head = v_layer[head//4] # value weights are shared across 4 heads
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)

    # concatenate all heads, project with wo, and add the residual
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit+output_after_feedforward
```
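
The long `output_after_feedforward` line is a SwiGLU feed-forward layer. Here is a minimal torch-free sketch of the same computation for a single token. `silu`, `matvec` and `swiglu_ffn` are hypothetical helper names, and the tiny 2-dimensional weights are made up for illustration. Each weight matrix is stored as rows of output features, so `matvec(w, x)` plays the role of `x @ w.T`.

```python
import math

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def matvec(matrix, vec):
    """Equivalent of vec @ matrix.T for a single token."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]

def swiglu_ffn(x, w1, w2, w3):
    gate = [silu(g) for g in matvec(w1, x)]     # silu(x @ w1.T)
    up = matvec(w3, x)                          # x @ w3.T
    hidden = [g * u for g, u in zip(gate, up)]  # elementwise gating
    return matvec(w2, hidden)                   # project back to the model dim

# toy weights: model dim 2, hidden dim 3 (illustrative values only)
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w3 = [[1.0, 1.0], [1.0, -1.0], [0.0, 1.0]]
w2 = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]

out = swiglu_ffn([1.0, 2.0], w1, w2, w3)
print(out)
```

The `w1` branch acts as a gate on the `w3` branch, and `w2` maps the gated hidden vector back to the embedding size so it can be added to the residual stream.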

```python
# final RMS norm, then project the last token's embedding onto the vocabulary
final_embedding = rms_norm(final_embedding, model["norm.weight"])
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
# greedy decoding: pick the highest-scoring token id
next_token = torch.argmax(logits, dim=-1)
tokenizer.decode([next_token.item()])
```
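
The last step is a single matrix-vector product followed by an argmax. A toy sketch, assuming a made-up three-token vocabulary (`toy_vocab`) and a hypothetical helper `greedy_next_token`, neither of which comes from the real model:

```python
def greedy_next_token(last_hidden, output_weight):
    """Score every vocabulary row against the last hidden state and return the best index."""
    logits = [sum(w * h for w, h in zip(row, last_hidden)) for row in output_weight]
    return max(range(len(logits)), key=logits.__getitem__)

# hypothetical vocabulary and unembedding matrix, one row per token
toy_vocab = ["the", " data", "!"]
toy_output_weight = [[0.5, 0.5], [1.0, -1.0], [-1.0, 0.0]]

next_id = greedy_next_token([1.0, -1.0], toy_output_weight)
print(repr(toy_vocab[next_id]))
```

Only the last token's embedding is used, since that is the position whose output predicts the next token.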