diff --git a/README.md b/README.md
index 2316c33..8973fd8 100644
--- a/README.md
+++ b/README.md
@@ -1580,13 +1580,28 @@ The difference between DPO and online PPO is that both reject and chosen are prepared offline, unlike minimind
 ## Ⅳ RoPE Length Extrapolation
 
 MiniMind supports length extrapolation of the RoPE position encoding via the YaRN algorithm, enabling the model to handle text sequences longer than those seen during training.
-When running inference with `eval_llm.py`, simply add the `--inference_rope_scaling` flag to enable RoPE extrapolation:
+
+For native torch models, when running inference with `eval_llm.py`, simply add the `--inference_rope_scaling` flag to enable RoPE extrapolation:
 
 ```bash
 python eval_llm.py --weight full_sft --inference_rope_scaling
 ```
 
-The chart below shows a comparison of perplexity (PPL) before and after RoPE scaling at different lengths of the vernacular novel "Journey to the West". As can be seen, with RoPE scaling enabled, the model's performance on long texts improves significantly:
+For models in Transformers format, length extrapolation can be enabled by adding the following configuration to config.json:
+
+```json
+"rope_scaling": {
+    "type": "yarn",
+    "factor": 16.0,
+    "original_max_position_embeddings": 2048,
+    "beta_fast": 32.0,
+    "beta_slow": 1.0,
+    "attention_factor": 1.0
+}
+```
+
+On the MiniMind-Small model, we feed in different lengths of the vernacular novel "Journey to the West" and compare perplexity (PPL) before and after RoPE scaling.
+As the results show, with YaRN extrapolation enabled, the model's PPL on long texts drops significantly:
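The same YaRN settings can also be applied programmatically instead of editing config.json; a minimal sketch, assuming a local Transformers-format checkpoint at a hypothetical path:

```python
# Minimal sketch: apply the YaRN rope_scaling settings at load time.
# "./MiniMind2-Small" is a hypothetical local checkpoint path.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("./MiniMind2-Small", trust_remote_code=True)
config.rope_scaling = {
    "type": "yarn",
    "factor": 16.0,
    "original_max_position_embeddings": 2048,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
    "attention_factor": 1.0,
}
model = AutoModelForCausalLM.from_pretrained(
    "./MiniMind2-Small", config=config, trust_remote_code=True
)
```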
diff --git a/README_en.md b/README_en.md
index d4e2d3d..8c492b8 100644
--- a/README_en.md
+++ b/README_en.md
@@ -1547,13 +1547,32 @@ Personal subjective evaluation basically aligns with DeepSeek-R1, where:
 ## Ⅳ RoPE Long-text Extrapolation
 
 MiniMind supports RoPE position-encoding length extrapolation through the YaRN algorithm, enabling models to handle text sequences that exceed the training length.
-When using `eval_llm.py` for inference, just add the `--inference_rope_scaling` parameter to enable RoPE extrapolation:
+
+For native torch models, when using `eval_llm.py` for inference, just add the `--inference_rope_scaling` parameter to enable RoPE extrapolation:
 
 ```bash
 python eval_llm.py --weight full_sft --inference_rope_scaling
 ```
 
-The chart below shows a perplexity (PPL) comparison before and after RoPE scaling on different lengths of "Journey to the West" vernacular fiction text. After enabling RoPE scaling, model performance on long texts improves significantly.
+For models in Transformers format, add the following configuration to config.json to enable length extrapolation:
+
+```json
+"rope_scaling": {
+    "type": "yarn",
+    "factor": 16.0,
+    "original_max_position_embeddings": 2048,
+    "beta_fast": 32.0,
+    "beta_slow": 1.0,
+    "attention_factor": 1.0
+}
+```
+
+We test the MiniMind-Small model on different lengths of "Journey to the West" vernacular fiction text and compare perplexity (PPL) before and after RoPE scaling.
+As the results show, after enabling YaRN extrapolation, the model's PPL on long texts drops significantly:
+
+![rope_ppl](./images/rope_ppl.png)
 
 ## Ⅴ Objective Benchmarks
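For reference, a minimal sketch of how such a PPL-vs-length measurement could be reproduced with a Transformers-format model; the checkpoint path, text file, and 8192-token window are illustrative assumptions, not the repo's exact evaluation setup:

```python
# Hedged sketch: full-context perplexity of a long passage.
# Paths and the 8192-token window are assumptions for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "./MiniMind2-Small"  # hypothetical local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()

text = open("journey_to_the_west.txt", encoding="utf-8").read()  # hypothetical text file
ids = tokenizer(text, return_tensors="pt").input_ids[:, :8192]   # one long-context window

with torch.no_grad():
    # HF causal LMs shift labels internally; .loss is the mean next-token cross-entropy
    loss = model(ids, labels=ids).loss
print(f"PPL@{ids.shape[1]} tokens: {torch.exp(loss).item():.2f}")
```

Sweeping the window length (e.g. 2k/4k/8k/16k) with and without the rope_scaling block in config.json yields the kind of PPL-vs-length curve plotted above.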
diff --git a/images/rope_ppl.png b/images/rope_ppl.png
index 223292e..ea07260 100644
Binary files a/images/rope_ppl.png and b/images/rope_ppl.png differ
diff --git a/model/model_minimind.py b/model/model_minimind.py
index 259f0af..4245000 100755
--- a/model/model_minimind.py
+++ b/model/model_minimind.py
@@ -54,12 +54,13 @@ class MiniMindConfig(PretrainedConfig):
         self.rms_norm_eps = rms_norm_eps
         self.rope_theta = rope_theta
         self.inference_rope_scaling = inference_rope_scaling
-        # extrapolated length = factor * original_max_position_embeddings
+        # extrapolated length = factor * original_max_position_embeddings = 32768
         self.rope_scaling = {
-            "beta_fast": 4,
+            "beta_fast": 32,
             "beta_slow": 1,
-            "factor": 4,
+            "factor": 16,
             "original_max_position_embeddings": 2048,
+            "attention_factor": 1.0,
             "type": "yarn"
         } if self.inference_rope_scaling else None
         self.flash_attn = flash_attn
@@ -107,24 +108,23 @@ class RMSNorm(torch.nn.Module):
 
 def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                          rope_scaling: Optional[dict] = None):
-    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
     if rope_scaling is not None:
-        orig_max, factor, beta_fast, beta_slow = (
-            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 4),
-            rope_scaling.get("beta_fast", 4.0), rope_scaling.get("beta_slow", 1.0)
+        orig_max, factor, beta_fast, beta_slow, attn_factor = (
+            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
+            rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
         )
         if end / orig_max > 1.0:
-            corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
-            power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
-            beta = beta_slow + (beta_fast - beta_slow) * power
-            # λ = (β·α - β + 1)/(β·α), the standard YaRN formula
-            scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim, (beta * factor - beta + 1) / (beta * factor), 1.0 / factor)
-            freqs = freqs * scale
+            # YaRN: f'(i) = f(i)·((1-γ) + γ/s), where γ∈[0,1] is a linear ramp
+            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
+            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
+            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
+            freqs = freqs * (1 - ramp + ramp / factor)
     t = torch.arange(end, device=freqs.device)
     freqs = torch.outer(t, freqs).float()
-    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
-    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
+    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
+    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
     return freqs_cos, freqs_sin
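To see what the new ramp does, here is a standalone sketch reproducing the frequency scaling from `precompute_freqs_cis` and checking its two limits; `dim = 64` is an assumed per-head dimension, while `rope_base`, `factor`, and the β values match the defaults in the diff:

```python
# Standalone sketch of the YaRN linear-ramp scaling used in precompute_freqs_cis.
# dim = 64 is an assumed per-head dimension; other values match the diff's defaults.
import math
import torch

dim, rope_base = 64, 1e6
orig_max, factor, beta_fast, beta_slow = 2048, 16, 32.0, 1.0

freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2).float() / dim))
# invert the wavelength condition 2π·b / f(i) = orig_max to find the boundary index for β
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
low = max(math.floor(inv_dim(beta_fast)), 0)              # last fully "fast" (high-frequency) dim
high = min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)   # first fully "slow" (low-frequency) dim
ramp = torch.clamp((torch.arange(dim // 2).float() - low) / max(high - low, 0.001), 0, 1)
scale = 1 - ramp + ramp / factor  # f'(i) = f(i)·((1-γ) + γ/s)

print(low, high)     # e.g. 5 14 for these settings
print(scale[:low])   # all 1.0  -> high-frequency dims keep their original rotation speed
print(scale[high:])  # all 1/16 -> low-frequency dims are fully interpolated by the factor
```

Because only the low-frequency (long-wavelength) dimensions are compressed, the local token relationships encoded by the fast dimensions are preserved while the usable position range stretches to factor × original_max_position_embeddings = 32768.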