mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
[feat] update yarn
This commit is contained in:
parent
6b86ea399a
commit
151fdf7e76
19
README.md
19
README.md
@ -1580,13 +1580,28 @@ DPO和在线PPO的区别在于reject和chosen都是离线准备的,和minimind
|
||||
## Ⅳ RoPE长度外推
|
||||
|
||||
MiniMind支持通过YaRN算法进行RoPE位置编码的长度外推,使模型能够处理超出训练长度的文本序列。
|
||||
在使用`eval_llm.py`进行推理时,只需添加`--inference_rope_scaling`参数即可启用RoPE外推:
|
||||
|
||||
原生torch模型在使用`eval_llm.py`进行推理时,只需添加`--inference_rope_scaling`参数即可启用RoPE外推:
|
||||
|
||||
```bash
|
||||
python eval_llm.py --weight full_sft --inference_rope_scaling
|
||||
```
|
||||
|
||||
下图展示了在不同文本「西游记」白话文小说长度下,使用RoPE scaling前后的困惑度(PPL)对比。可以看出,启用RoPE scaling后,模型在长文本上的表现显著提升:
|
||||
对于Transformers格式的模型,可以在config.json中添加以下配置实现长度外推:
|
||||
|
||||
```json
|
||||
"rope_scaling": {
|
||||
"type": "yarn",
|
||||
"factor": 16.0,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"beta_fast": 32.0,
|
||||
"beta_slow": 1.0,
|
||||
"attention_factor": 1.0
|
||||
}
|
||||
```
|
||||
|
||||
在MiniMind-Small模型上,测试输入不同长度的「西游记」白话文小说,评估RoPE scaling前后的困惑度(PPL)对比。
|
||||
可以看出,启用YaRN外推后,模型在长文本上的PPL表现显著下降:
|
||||
|
||||
<div align="center">
|
||||
<img src="./images/rope_ppl.png">
|
||||
|
||||
23
README_en.md
23
README_en.md
@ -1547,13 +1547,32 @@ Personal subjective evaluation basically aligns with DeepSeek-R1, where:
|
||||
## Ⅳ RoPE Long-text Extrapolation
|
||||
|
||||
MiniMind supports RoPE position encoding length extrapolation through YaRN algorithm, enabling models to handle text sequences exceeding training length.
|
||||
When using `eval_llm.py` for inference, just add `--inference_rope_scaling` parameter to enable RoPE extrapolation:
|
||||
|
||||
For native torch models, when using `eval_llm.py` for inference, just add `--inference_rope_scaling` parameter to enable RoPE extrapolation:
|
||||
|
||||
```bash
|
||||
python eval_llm.py --weight full_sft --inference_rope_scaling
|
||||
```
|
||||
|
||||
The chart below shows perplexity (PPL) comparison before and after RoPE scaling on different lengths of "Journey to the West" vernacular fiction text. You can see that after enabling RoPE scaling, model performance on long texts is significantly improved.
|
||||
For Transformers format models, add the following configuration to config.json to enable length extrapolation:
|
||||
|
||||
```json
|
||||
"rope_scaling": {
|
||||
"type": "yarn",
|
||||
"factor": 16.0,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"beta_fast": 32.0,
|
||||
"beta_slow": 1.0,
|
||||
"attention_factor": 1.0
|
||||
}
|
||||
```
|
||||
|
||||
Testing on MiniMind-Small model with different lengths of "Journey to the West" vernacular fiction text to evaluate perplexity (PPL) comparison before and after RoPE scaling.
|
||||
You can see that after enabling YaRN extrapolation, the model's PPL performance on long texts significantly decreases:
|
||||
|
||||
<div align="center">
|
||||
<img src="./images/rope_ppl.png">
|
||||
</div>
|
||||
|
||||
## Ⅴ Objective Benchmarks
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 144 KiB After Width: | Height: | Size: 79 KiB |
@ -54,12 +54,13 @@ class MiniMindConfig(PretrainedConfig):
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.rope_theta = rope_theta
|
||||
self.inference_rope_scaling = inference_rope_scaling
|
||||
# 外推长度 = factor * original_max_position_embeddings
|
||||
# 外推长度 = factor * original_max_position_embeddings = 32768
|
||||
self.rope_scaling = {
|
||||
"beta_fast": 4,
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
"factor": 4,
|
||||
"factor": 16,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"attention_factor": 1.0,
|
||||
"type": "yarn"
|
||||
} if self.inference_rope_scaling else None
|
||||
self.flash_attn = flash_attn
|
||||
@ -107,24 +108,23 @@ class RMSNorm(torch.nn.Module):
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
|
||||
rope_scaling: Optional[dict] = None):
|
||||
freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
|
||||
if rope_scaling is not None:
|
||||
orig_max, factor, beta_fast, beta_slow = (
|
||||
rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 4),
|
||||
rope_scaling.get("beta_fast", 4.0), rope_scaling.get("beta_slow", 1.0)
|
||||
orig_max, factor, beta_fast, beta_slow, attn_factor = (
|
||||
rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
|
||||
rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
|
||||
)
|
||||
if end / orig_max > 1.0:
|
||||
corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
|
||||
power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
|
||||
beta = beta_slow + (beta_fast - beta_slow) * power
|
||||
# λ = (β·α - β + 1)/(β·α) YaRN标准公式
|
||||
scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim, (beta * factor - beta + 1) / (beta * factor), 1.0 / factor)
|
||||
freqs = freqs * scale
|
||||
# YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
|
||||
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
|
||||
low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
|
||||
ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
|
||||
freqs = freqs * (1 - ramp + ramp / factor)
|
||||
|
||||
t = torch.arange(end, device=freqs.device)
|
||||
freqs = torch.outer(t, freqs).float()
|
||||
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
|
||||
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
|
||||
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
|
||||
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user