[feat] update yarn

jingyaogong 2025-12-01 16:15:05 +08:00
parent 6b86ea399a
commit 151fdf7e76
4 changed files with 53 additions and 19 deletions

@@ -1580,13 +1580,28 @@ The difference between DPO and online PPO is that both reject and chosen are prepared offline, and minimind
 ## Ⅳ RoPE Long-text Extrapolation
 MiniMind supports RoPE position encoding length extrapolation through the YaRN algorithm, enabling models to handle text sequences exceeding the training length.
-When using `eval_llm.py` for inference, just add the `--inference_rope_scaling` argument to enable RoPE extrapolation:
+For native torch models, when using `eval_llm.py` for inference, just add the `--inference_rope_scaling` argument to enable RoPE extrapolation:
 ```bash
 python eval_llm.py --weight full_sft --inference_rope_scaling
 ```
-The chart below shows the perplexity (PPL) comparison before and after RoPE scaling on different lengths of the vernacular "Journey to the West" novel. After enabling RoPE scaling, the model's performance on long text improves significantly:
+For models in Transformers format, length extrapolation can be enabled by adding the following configuration to config.json:
+```json
+"rope_scaling": {
+    "type": "yarn",
+    "factor": 16.0,
+    "original_max_position_embeddings": 2048,
+    "beta_fast": 32.0,
+    "beta_slow": 1.0,
+    "attention_factor": 1.0
+}
+```
+Tested on the MiniMind-Small model with inputs of different lengths from the vernacular "Journey to the West" novel, comparing perplexity (PPL) before and after RoPE scaling.
+After enabling YaRN extrapolation, the model's PPL on long text drops significantly:
 <div align="center">
 <img src="./images/rope_ppl.png">
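As a rough illustration of the config.json approach described in the hunk above, here is a minimal sketch (not part of this commit; the checkpoint directory `./MiniMind2-Small` is only an assumed example) that patches an existing Transformers-format checkpoint with the YaRN block:

```python
import json
from pathlib import Path

ckpt_dir = Path("./MiniMind2-Small")   # assumed local Transformers-format checkpoint
cfg_path = ckpt_dir / "config.json"

cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
cfg["rope_scaling"] = {                # values copied from the README snippet above
    "type": "yarn",
    "factor": 16.0,
    "original_max_position_embeddings": 2048,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
    "attention_factor": 1.0,
}
cfg_path.write_text(json.dumps(cfg, indent=2, ensure_ascii=False), encoding="utf-8")

# Afterwards the checkpoint can be loaded as usual, e.g. with
# transformers.AutoModelForCausalLM.from_pretrained(ckpt_dir, trust_remote_code=True)
```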

@@ -1547,13 +1547,32 @@ Personal subjective evaluation basically aligns with DeepSeek-R1, where:
 ## Ⅳ RoPE Long-text Extrapolation
 MiniMind supports RoPE position encoding length extrapolation through the YaRN algorithm, enabling models to handle text sequences exceeding the training length.
-When using `eval_llm.py` for inference, just add the `--inference_rope_scaling` parameter to enable RoPE extrapolation:
+For native torch models, when using `eval_llm.py` for inference, just add the `--inference_rope_scaling` parameter to enable RoPE extrapolation:
 ```bash
 python eval_llm.py --weight full_sft --inference_rope_scaling
 ```
-The chart below shows the perplexity (PPL) comparison before and after RoPE scaling on different lengths of "Journey to the West" vernacular fiction text. After enabling RoPE scaling, model performance on long texts improves significantly.
+For Transformers-format models, add the following configuration to config.json to enable length extrapolation:
+```json
+"rope_scaling": {
+    "type": "yarn",
+    "factor": 16.0,
+    "original_max_position_embeddings": 2048,
+    "beta_fast": 32.0,
+    "beta_slow": 1.0,
+    "attention_factor": 1.0
+}
+```
+Testing on the MiniMind-Small model with different lengths of "Journey to the West" vernacular fiction text gives a perplexity (PPL) comparison before and after RoPE scaling.
+After enabling YaRN extrapolation, the model's PPL on long texts drops significantly:
 <div align="center">
 <img src="./images/rope_ppl.png">
 </div>
 ## Objective Benchmarks
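The PPL comparison above can be reproduced along these lines; this is only a hedged sketch assuming a causal LM and tokenizer loaded through `transformers`, not the repository's `eval_llm.py`:

```python
import math
import torch

def perplexity(model, tokenizer, text: str, max_len: int = 8192) -> float:
    """Mean token cross-entropy on one long passage, converted to perplexity."""
    ids = tokenizer(text, return_tensors="pt").input_ids[:, :max_len].to(model.device)
    with torch.no_grad():
        loss = model(ids, labels=ids).loss  # standard causal-LM loss over the passage
    return math.exp(loss.item())

# Running this on the same "Journey to the West" passage with and without the YaRN
# rope_scaling config yields the kind of before/after curve shown in rope_ppl.png.
```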

Binary image file changed (preview not shown): 144 KiB before, 79 KiB after.

@@ -54,12 +54,13 @@ class MiniMindConfig(PretrainedConfig):
         self.rms_norm_eps = rms_norm_eps
         self.rope_theta = rope_theta
         self.inference_rope_scaling = inference_rope_scaling
-        # extrapolated length = factor * original_max_position_embeddings
+        # extrapolated length = factor * original_max_position_embeddings = 32768
         self.rope_scaling = {
-            "beta_fast": 4,
+            "beta_fast": 32,
             "beta_slow": 1,
-            "factor": 4,
+            "factor": 16,
             "original_max_position_embeddings": 2048,
+            "attention_factor": 1.0,
             "type": "yarn"
         } if self.inference_rope_scaling else None
         self.flash_attn = flash_attn
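A trivial check of the arithmetic in the updated comment (not from the commit itself):

```python
# extrapolated context length = factor * original_max_position_embeddings
rope_scaling = {"factor": 16, "original_max_position_embeddings": 2048}
assert rope_scaling["factor"] * rope_scaling["original_max_position_embeddings"] == 32768
```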
@@ -107,24 +108,23 @@ class RMSNorm(torch.nn.Module):
 def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                          rope_scaling: Optional[dict] = None):
-    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
     if rope_scaling is not None:
-        orig_max, factor, beta_fast, beta_slow = (
-            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 4),
-            rope_scaling.get("beta_fast", 4.0), rope_scaling.get("beta_slow", 1.0)
+        orig_max, factor, beta_fast, beta_slow, attn_factor = (
+            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
+            rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
         )
         if end / orig_max > 1.0:
-            corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
-            power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
-            beta = beta_slow + (beta_fast - beta_slow) * power
-            # λ = (β·α - β + 1)/(β·α), standard YaRN formula
-            scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim, (beta * factor - beta + 1) / (beta * factor), 1.0 / factor)
-            freqs = freqs * scale
+            # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is a linear ramp
+            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
+            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
+            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
+            freqs = freqs * (1 - ramp + ramp / factor)
     t = torch.arange(end, device=freqs.device)
     freqs = torch.outer(t, freqs).float()
-    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
-    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
+    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
+    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
     return freqs_cos, freqs_sin
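For context, a small usage sketch of the updated function; the import path `model.model_minimind` is an assumption about the repository layout, not something stated in this diff:

```python
from model.model_minimind import precompute_freqs_cis  # assumed module path

head_dim = 64
yarn_cfg = {
    "type": "yarn", "factor": 16, "original_max_position_embeddings": 2048,
    "beta_fast": 32, "beta_slow": 1, "attention_factor": 1.0,
}

# Plain RoPE table vs. YaRN-scaled table, both covering 32 * 1024 positions.
cos_plain, sin_plain = precompute_freqs_cis(dim=head_dim, end=32 * 1024)
cos_yarn, sin_yarn = precompute_freqs_cis(dim=head_dim, end=32 * 1024, rope_scaling=yarn_cfg)

# Same shape either way; only the scaled table interpolates the low-frequency
# dimensions so positions beyond the original 2048-token window stay well-behaved.
print(cos_plain.shape, cos_yarn.shape)  # torch.Size([32768, 64]) torch.Size([32768, 64])
```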