diff --git a/README.md b/README.md
index 8ecc4c5..537c9c3 100644
--- a/README.md
+++ b/README.md
@@ -123,6 +123,43 @@
 We hope this open-source project helps LLM beginners get started quickly!
 
+## 📚 Beginner Learning Path (read this first)
+
+If you are new to LLMs, start with this roadmap: get the training loop running end to end before reading the source code:
+`docs/learning_path.md`
+
+## 🧪 Dynamic Neuron Growth Experiments (paper track)
+
+An experiment guide and comparison commands for dynamic neuron growth (at the FFN level) are provided:
+`docs/experiments_dynamic_growth.md`
+
+Fixed-prompt evaluation script:
+`scripts/eval_fixed_prompts.py`
+
+PPL / cross-entropy evaluation script:
+`scripts/eval_ppl.py`
+
+One-shot comparison sweep script (prints commands by default; add `--run` to execute):
+`scripts/run_growth_sweep.py`
+
+Fixed validation split generator:
+`scripts/make_val_split.py`
+
+Experiment results template:
+`eval/results_template.csv`
+
+Evaluation aggregation script:
+`scripts/aggregate_eval.py`
+
+Quick plotting script:
+`scripts/plot_growth.py`
+
+One-shot paper pipeline (train / evaluate / aggregate / plot):
+`scripts/run_paper_pipeline.py` (config: `eval/pipeline_config.json`)
+
+Overnight one-shot script (auto data download / fault-tolerant / auto cache cleanup):
+`scripts/run_overnight_pipeline.py` (config: `eval/pipeline_config_overnight.json`)
+
 ### 👉**Changelog**
@@ -1965,5 +2002,3 @@ If you find MiniMind helpful in your research or work, please cite:
 
 # License
 This repository is licensed under the [Apache-2.0 License](LICENSE).
-
-
diff --git a/docs/experiments_dynamic_growth.md b/docs/experiments_dynamic_growth.md
new file mode 100644
index 0000000..cfa3b23
--- /dev/null
+++ b/docs/experiments_dynamic_growth.md
@@ -0,0 +1,225 @@
+# Dynamic Neuron Growth Experiment Guide (MiniMind)
+
+This document helps you turn "dynamic neuron growth (at the FFN level)" into **reproducible experiments**, with the comparisons and records a paper requires.
+
+You can also use the one-shot pipeline directly:
+
+```bash
+python scripts/run_paper_pipeline.py --run
+```
+
+The config file is `eval/pipeline_config.json`.
+
+---
+
+## 1. Main Method and Baselines
+
+**Main method**
+- Activity + gradient driven growth (`grow_method=act_grad`)
+
+**Baselines**
+- Baseline: no growth (`neuron_growth=0`)
+- Random: random growth (`grow_method=random`)
+- Grad-only: gradient-only growth (`grow_method=act_grad, grow_score_alpha=0, grow_score_beta=1`)
+- Act-only: activity-only growth (`grow_method=act_grad, grow_score_alpha=1, grow_score_beta=0`)
+
+---
+
+## 2. Recommended Commands (Pretraining)
+
+> The commands below use `train_pretrain.py` as the example; SFT with `train_full_sft.py` works the same way.
+
+**Baseline (no growth)**
+```bash
+python trainer/train_pretrain.py \
+    --save_weight pretrain_baseline \
+    --neuron_growth 0
+```
+
+**Random growth**
+```bash
+python trainer/train_pretrain.py \
+    --save_weight pretrain_random \
+    --neuron_growth 1 \
+    --init_active_ratio 0.8 \
+    --grow_method random \
+    --grow_interval 100 \
+    --grow_ratio 0.02 \
+    --max_active_ratio 0.99
+```
+
+**Grad-only growth (gradient signal only)**
+```bash
+python trainer/train_pretrain.py \
+    --save_weight pretrain_grad \
+    --neuron_growth 1 \
+    --init_active_ratio 0.8 \
+    --grow_method act_grad \
+    --grow_interval 100 \
+    --grow_ratio 0.02 \
+    --max_active_ratio 0.99 \
+    --grow_score_alpha 0.0 \
+    --grow_score_beta 1.0
+```
+
+**Act-only growth (activity signal only)**
+```bash
+python trainer/train_pretrain.py \
+    --save_weight pretrain_act \
+    --neuron_growth 1 \
+    --init_active_ratio 0.8 \
+    --grow_method act_grad \
+    --grow_interval 100 \
+    --grow_ratio 0.02 \
+    --max_active_ratio 0.99 \
+    --grow_score_alpha 1.0 \
+    --grow_score_beta 0.0
+```
+
+**Act+Grad (main method)**
+```bash
+python trainer/train_pretrain.py \
+    --save_weight pretrain_actgrad \
+    --neuron_growth 1 \
+    --init_active_ratio 0.8 \
+    --grow_method act_grad \
+    --grow_interval 100 \
+    --grow_ratio 0.02 \
+    --max_active_ratio 0.99 \
+    --grow_score_alpha 1.0 \
+    --grow_score_beta 1.0
+```
+
+---
+
+## 3. Recommended Commands (SFT)
+
+Swap the script for `trainer/train_full_sft.py`; all other arguments stay the same.
+
+```bash
+python trainer/train_full_sft.py \
+    --save_weight full_sft_actgrad \
+    --neuron_growth 1 \
+    --init_active_ratio 0.8 \
+    --grow_method act_grad \
+    --grow_interval 100 \
+    --grow_ratio 0.02 \
+    --max_active_ratio 0.99 \
+    --grow_score_alpha 1.0 \
+    --grow_score_beta 1.0
+```
+
+---
+
+## 4. Key Parameters (fix these first)
+
+- `init_active_ratio`: initial fraction of active neurons; 0.8 recommended
+- `grow_interval`: grow once every this many *optimizer update steps* (not batches)
+- `grow_ratio`: fraction of neurons newly activated per growth event (0.01~0.05 recommended)
+- `max_active_ratio`: maximum active fraction (0.95~1.0)
+- `grow_score_alpha/beta`: weights of the activity / gradient scores
+- `neuron_ema_beta`: EMA coefficient for the activity statistics (0.05~0.2)
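+
+As an illustration of how these parameters interact (a minimal sketch, not the trainer's actual code), one growth event might look like this, assuming an FFN module exposing the `mask` and `ema_act` buffers added in `model/model_minimind.py`; `grad_scores` stands for the per-neuron gradient magnitudes collected via the mask proxy:
+
+```python
+import torch
+
+def growth_step(ffn, grad_scores, alpha=1.0, beta=1.0,
+                grow_ratio=0.02, max_active_ratio=0.99):
+    """Activate the best-scoring inactive neurons of one FFN layer."""
+    mask = ffn.mask                                   # 1 = active, 0 = masked
+    n = mask.numel()
+    inactive = (mask == 0).nonzero(as_tuple=True)[0]
+    budget = min(int(n * grow_ratio),
+                 int(n * max_active_ratio) - int(mask.sum().item()))
+    if budget <= 0 or inactive.numel() == 0:
+        return
+    score = alpha * ffn.ema_act + beta * grad_scores  # combined growth score
+    top = score[inactive].topk(min(budget, inactive.numel())).indices
+    mask[inactive[top]] = 1.0                         # activate the winners
+```
+
+Every `grow_interval` optimizer steps, one such event raises the active fraction by roughly `grow_ratio`, from `init_active_ratio` toward `max_active_ratio` (score normalization is omitted here and may differ in the actual trainer).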
+
+---
+
+## 5. Record Template (essential when writing the paper)
+
+Record the following for every training run (ideally as a table):
+
+| Run Name | Model Size | Data Version | Steps | LR | Batch | Seq Len | Growth Method | init_ratio | grow_interval | grow_ratio | max_ratio | PPL | Notes |
+|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
+
+**Example record**
+- Run Name: pretrain_actgrad_v1
+- Model Size: 26M (hidden=512, layers=8)
+- Data Version: pretrain_hq_v1
+- Steps: 20k
+- Growth Method: act_grad
+- PPL: 18.3
+- Notes: stable, active_ratio≈0.97
+
+The training scripts automatically write a `*_config.json` under `save_dir` containing:
+- the hyperparameter config
+- the random seed
+- the git commit
+- wall-clock time and total tokens (written after training finishes)
+
+---
+
+## 6. Lightweight Evaluation (use one unified procedure)
+
+Use a fixed prompt list for side-by-side comparisons:
+
+```bash
+python scripts/eval_fixed_prompts.py \
+    --weight pretrain_actgrad \
+    --prompts_file eval/prompts_minimal.jsonl \
+    --out_dir eval_runs \
+    --config eval/eval_config.json
+```
+
+First generate a **fixed validation split** (so experiments stay reproducible):
+
+```bash
+python scripts/make_val_split.py \
+    --data_path dataset/pretrain_hq.jsonl \
+    --out_path eval/val_pretrain.jsonl \
+    --val_size 2000 \
+    --seed 42
+```
+
+Then run the PPL / cross-entropy evaluation as the quantitative metric for the paper:
+
+```bash
+python scripts/eval_ppl.py \
+    --weight pretrain_actgrad \
+    --data_path eval/val_pretrain.jsonl \
+    --max_seq_len 340 \
+    --batch_size 8
+```
+
+Results can be filled into the template: `eval/results_template.csv`
+
+Use the aggregation script to generate a CSV:
+
+```bash
+python scripts/aggregate_eval.py --inputs eval_runs_ppl --out eval/summary_ppl.csv
+```
+
+And the plotting script for a quick chart:
+
+```bash
+python scripts/plot_growth.py --csv eval/summary_ppl.csv --x weight --y ppl --out eval/plot_ppl.png
+```
+
+To compute mean/std across multiple seeds:
+
+```bash
+python scripts/aggregate_eval.py \
+    --inputs eval_runs_ppl \
+    --out eval/summary_ppl.csv \
+    --group_by method \
+    --out_grouped eval/summary_ppl_grouped.csv
+```
+
+---
+
+## 7. Minimal Publishable Experiment Set
+
+1. Baseline
+2. Random
+3. Grad-only
+4. Act-only
+5. Act+Grad (main method)
+
+As long as the main method outperforms settings 1~4 under the same budget, the result is worth writing up.
+
+Run at least 3 random seeds per setting; the sweep script handles this:
+
+```bash
+python scripts/run_growth_sweep.py \
+    --script trainer/train_pretrain.py \
+    --prefix exp \
+    --seeds 42,123,2026 \
+    --base_args "--epochs 1 --batch_size 32"
+```
diff --git a/docs/learning_path.md b/docs/learning_path.md
new file mode 100644
index 0000000..7993309
--- /dev/null
+++ b/docs/learning_path.md
@@ -0,0 +1,124 @@
+# MiniMind Beginner Learning Path (with quick-reference basics)
+
+This guide is aimed at **complete LLM beginners**: the goal is the shortest path to understanding and running MiniMind's training loop end to end, then digging deeper into the source code.
+
+**What you will get**
+- A clear learning path (from "it runs" to "I understand the source")
+- The minimum background knowledge needed before training
+- Common errors and where to look when debugging
+- A quick reference for the essential concepts (shapes, loss, mixed precision, DDP, etc.)
+
+---
+
+## 1. Shortest Viable Path (run it first, go deep later)
+
+**Goal**: complete one full pretraining run first, then come back to the source.
+
+1. Make sure `trainer/train_pretrain.py` runs (a minimal dataset is fine)
+2. Run one training session and watch the loss being printed
+3. Run inference once with `eval_llm.py`
+4. Only then start reading the source (in the recommended order)
+
+> Experience says: getting it to run first and reading the source second beats grinding through the model code from day one.
+
+---
+
+## 2. Source Reading Order (strongly recommended)
+
+1. `trainer/train_pretrain.py`
+2. `dataset/lm_dataset.py`
+3. `model/model_minimind.py`
+4. `trainer/trainer_utils.py`
+5. `eval_llm.py`
+
+**Why**: this order maps onto "data → model → training → save/restore → inference".
+
+---
+
+## 3. Essential Background (the minimum set)
+
+**Python basics**
+- List repetition: `[0] * N`
+- Slicing: `a[:-1]` / `a[1:]`
+- `...` (Ellipsis) for multi-dimensional slicing
+- Classes and the meaning of `self`
+
+**The minimal PyTorch loop**
+- forward → loss → backward → step
+- `Dataset / DataLoader`
+- `Tensor` shape / dtype / device
+- `torch.nn.functional.cross_entropy`
+
+**Minimal LLM training concepts**
+- The shift logic behind "predict the next token"
+- `-100` in `labels` means the position is ignored by the loss
+- How `batch_size / seq_len / vocab_size` relate to tensor shapes
+
+---
+
+## 4. Shape Cheat Sheet (the key to understanding training)
+
+**Common tensor shapes**
+- `input_ids`: `(batch_size, seq_len)`
+- `logits`: `(batch_size, seq_len, vocab_size)`
+- `shift_logits`: `(batch_size, seq_len-1, vocab_size)`
+- `shift_labels`: `(batch_size, seq_len-1)`
+
+**Loss logic**
+- The model learns to "predict the next token"
+- So `logits[:, :-1]` must be aligned with `labels[:, 1:]`
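+
+A runnable toy example of this alignment (all sizes are made up for illustration):
+
+```python
+import torch
+import torch.nn.functional as F
+
+batch, seq_len, vocab = 2, 6, 50                 # hypothetical sizes
+logits = torch.randn(batch, seq_len, vocab)      # (batch, seq_len, vocab)
+labels = torch.randint(0, vocab, (batch, seq_len))
+labels[:, :2] = -100                             # e.g. prompt tokens: no loss
+
+shift_logits = logits[:, :-1, :].contiguous()    # prediction for token t+1
+shift_labels = labels[:, 1:].contiguous()        # the actual token t+1
+loss = F.cross_entropy(
+    shift_logits.view(-1, vocab),
+    shift_labels.view(-1),
+    ignore_index=-100,                           # -100 positions are skipped
+)
+print(loss.item())                               # ~ln(50) for random logits
+```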
+
+---
+
+## 5. Understanding the Training Parameters (beginner edition)
+
+- `batch_size`: larger means more GPU memory
+- `max_seq_len`: memory grows quickly as this increases
+- `hidden_size / num_hidden_layers`: determine the model size
+- `learning_rate`: too large and training easily diverges
+- `accumulation_steps`: lets a small GPU simulate a large batch
+- `dtype`: `bfloat16` is stable; `float16` is faster but less stable
+
+---
+
+## 6. Common Problems and Debugging Order
+
+**1) Data file not found**
+- First check that `--data_path` exists
+
+**2) Errors mentioning `../model` or `../dataset`**
+- You are running from the wrong directory
+- Run the script from inside the `trainer/` directory
+
+**3) OOM (not enough GPU memory)**
+- Lower `batch_size` first
+- Then lower `max_seq_len`
+- Then shrink `hidden_size/num_hidden_layers`
+
+**4) Loss becomes NaN**
+- First divide `learning_rate` by 10
+- Check the data for anomalies
+
+---
+
+## 7. Recommended Exercises (fast ways to build understanding)
+
+1. Overfit on 5 lines of data (the loss should drop quickly)
+2. Increase `max_seq_len` and watch GPU memory usage change
+3. Increase `learning_rate` and watch the loss diverge
+4. Change `hidden_size` to feel the trade-off between model size and speed
+
+---
+
+## 8. Next Steps
+
+You have genuinely gotten started once you can explain these three things:
+- how text becomes `input_ids/labels`
+- where the loss is computed and why the shift is needed
+- how the parameters are updated inside the training loop
+
+If you are still unsure, go back to:
+- `trainer/train_pretrain.py`
+- `dataset/lm_dataset.py`
+- `model/model_minimind.py`
+
diff --git a/eval/eval_config.json b/eval/eval_config.json
new file mode 100644
index 0000000..9e3be2a
--- /dev/null
+++ b/eval/eval_config.json
@@ -0,0 +1,8 @@
+{
+    "max_new_tokens": 256,
+    "temperature": 0.7,
+    "top_p": 0.9,
+    "do_sample": 0,
+    "seed": 2026,
+    "use_chat": -1
+}
diff --git a/eval/pipeline_config.json b/eval/pipeline_config.json
new file mode 100644
index 0000000..015d9fd
--- /dev/null
+++ b/eval/pipeline_config.json
@@ -0,0 +1,71 @@
+{
+    "stages": ["make_val", "train_pretrain", "eval_ppl", "eval_prompts", "aggregate", "plot"],
+    "paths": {
+        "pretrain_data": "dataset/pretrain_hq.jsonl",
+        "sft_data": "dataset/sft_mini_512.jsonl",
+        "val_data": "eval/val_pretrain.jsonl",
+        "prompts": "eval/prompts_minimal.jsonl",
+        "eval_config": "eval/eval_config.json",
+        "out_dir": "out",
+        "eval_runs": "eval_runs",
+        "eval_runs_ppl": "eval_runs_ppl",
+        "summary_csv": "eval/summary_ppl.csv",
+        "plot_png": "eval/plot_ppl.png"
+    },
+    "pretrain": {
+        "script": "trainer/train_pretrain.py",
+        "save_prefix": "exp_pretrain",
+        "args": {
+            "epochs": 1,
+            "batch_size": 32,
+            "learning_rate": 0.0005,
+            "dtype": "bfloat16",
+            "num_workers": 8,
+            "accumulation_steps": 8,
+            "grad_clip": 1.0,
+            "max_seq_len": 340,
+            "hidden_size": 512,
+            "num_hidden_layers": 8,
+            "use_moe": 0,
+            "use_compile": 0
+        },
+        "flags": []
+    },
+    "sft": {
+        "enabled": false,
+        "script": "trainer/train_full_sft.py",
+        "save_prefix": "exp_sft",
+        "from_weight_mode": "fixed",
+        "from_weight": "pretrain",
+        "args": {
+            "epochs": 1,
+            "batch_size": 16,
+            "learning_rate": 1e-6,
+            "dtype": "bfloat16",
+            "num_workers": 8,
+            "accumulation_steps": 1,
+            "grad_clip": 1.0,
+            "max_seq_len": 340,
+            "hidden_size": 512,
+            "num_hidden_layers": 8,
+            "use_moe": 0,
+            "use_compile": 0
+        },
+        "flags": []
+    },
+    "growth": {
+        "init_active_ratio": 0.8,
+        "grow_interval": 100,
+        "grow_ratio": 0.02,
+        "max_active_ratio": 0.99,
+        "neuron_ema_beta": 0.1
+    },
+    "methods": ["baseline", "random", "grad", "act", "actgrad"],
+    "seeds": [42, 123, 2026],
+    "eval": {
+        "target": "pretrain",
+        "max_seq_len": 340,
+        "batch_size": 8,
+        "max_samples": 0
+    }
+}
diff --git a/eval/pipeline_config_overnight.json b/eval/pipeline_config_overnight.json
new file mode 100644
index 0000000..8301ad7
--- /dev/null
+++ b/eval/pipeline_config_overnight.json
@@ -0,0 +1,107 @@
+{
+    "stages": [
+        "preflight",
+        "prepare",
+        "download",
+        "make_val",
+        "train_pretrain",
+        "train_sft",
+        "eval_ppl",
+        "eval_prompts",
+        "aggregate",
+        "plot",
+        "cleanup"
+    ],
+    "dataset": {
+        "repo": "https://www.modelscope.cn/datasets/gongjy/minimind_dataset.git",
+        "target_dir": "minimind_dataset",
+        "download_mode": "selective",
+        "stage_files": {
+            "make_val": ["pretrain_hq.jsonl"],
+            "train_pretrain": ["pretrain_hq.jsonl"],
+            "train_sft": ["sft_512.jsonl"],
+            "eval_ppl": ["pretrain_hq.jsonl"]
+        }
+    },
+    "paths": {
+        "pretrain_data": "minimind_dataset/pretrain_hq.jsonl",
+        "sft_data": "minimind_dataset/sft_512.jsonl",
+        "val_data": "eval/val_pretrain.jsonl",
+        "prompts": "eval/prompts_minimal.jsonl",
+        "eval_config": "eval/eval_config.json",
+        "out_dir": "out",
+        "eval_runs": "eval_runs",
+        "eval_runs_ppl": "eval_runs_ppl",
+        "summary_csv": 
"eval/summary_ppl.csv", + "plot_png": "eval/plot_ppl.png", + "log_dir": "logs" + }, + "cache": { + "hf_home": ".cache/hf_home", + "hf_datasets": ".cache/hf_datasets", + "transformers": ".cache/hf_transformers" + }, + "runtime": { + "continue_on_error": true, + "min_free_gb": 6, + "retry_download": 2, + "cleanup_after_stage": ["train_pretrain", "train_sft", "eval_ppl", "eval_prompts"], + "cleanup_paths": [".cache/hf_datasets", ".cache/hf_transformers", ".cache/hf_home", "__pycache__"] + }, + "pretrain": { + "script": "trainer/train_pretrain.py", + "save_prefix": "exp_pretrain", + "args": { + "epochs": 1, + "batch_size": 32, + "learning_rate": 0.0005, + "dtype": "bfloat16", + "num_workers": 8, + "accumulation_steps": 8, + "grad_clip": 1.0, + "max_seq_len": 340, + "hidden_size": 512, + "num_hidden_layers": 8, + "use_moe": 0, + "use_compile": 0 + }, + "flags": [] + }, + "sft": { + "enabled": true, + "script": "trainer/train_full_sft.py", + "save_prefix": "exp_sft", + "from_weight_mode": "match_pretrain", + "from_weight": "pretrain", + "args": { + "epochs": 1, + "batch_size": 16, + "learning_rate": 1e-6, + "dtype": "bfloat16", + "num_workers": 8, + "accumulation_steps": 1, + "grad_clip": 1.0, + "max_seq_len": 340, + "hidden_size": 512, + "num_hidden_layers": 8, + "use_moe": 0, + "use_compile": 0 + }, + "flags": [] + }, + "growth": { + "init_active_ratio": 0.8, + "grow_interval": 100, + "grow_ratio": 0.02, + "max_active_ratio": 0.99, + "neuron_ema_beta": 0.1 + }, + "methods": ["baseline", "random", "grad", "act", "actgrad"], + "seeds": [42, 123, 2026], + "eval": { + "targets": ["pretrain", "sft"], + "max_seq_len": 340, + "batch_size": 8, + "max_samples": 0 + } +} diff --git a/eval/pipeline_config_overnight_3090.json b/eval/pipeline_config_overnight_3090.json new file mode 100644 index 0000000..1f65ad1 --- /dev/null +++ b/eval/pipeline_config_overnight_3090.json @@ -0,0 +1,107 @@ +{ + "stages": [ + "preflight", + "prepare", + "download", + "make_val", + "train_pretrain", + "train_sft", + "eval_ppl", + "eval_prompts", + "aggregate", + "plot", + "cleanup" + ], + "dataset": { + "repo": "https://www.modelscope.cn/datasets/gongjy/minimind_dataset.git", + "target_dir": "minimind_dataset", + "download_mode": "selective", + "stage_files": { + "make_val": ["pretrain_hq.jsonl"], + "train_pretrain": ["pretrain_hq.jsonl"], + "train_sft": ["sft_512.jsonl"], + "eval_ppl": ["pretrain_hq.jsonl"] + } + }, + "paths": { + "pretrain_data": "minimind_dataset/pretrain_hq.jsonl", + "sft_data": "minimind_dataset/sft_512.jsonl", + "val_data": "eval/val_pretrain.jsonl", + "prompts": "eval/prompts_minimal.jsonl", + "eval_config": "eval/eval_config.json", + "out_dir": "out", + "eval_runs": "eval_runs", + "eval_runs_ppl": "eval_runs_ppl", + "summary_csv": "eval/summary_ppl.csv", + "plot_png": "eval/plot_ppl.png", + "log_dir": "logs" + }, + "cache": { + "hf_home": ".cache/hf_home", + "hf_datasets": ".cache/hf_datasets", + "transformers": ".cache/hf_transformers" + }, + "runtime": { + "continue_on_error": true, + "min_free_gb": 6, + "retry_download": 2, + "cleanup_after_stage": ["train_pretrain", "train_sft", "eval_ppl", "eval_prompts"], + "cleanup_paths": [".cache/hf_datasets", ".cache/hf_transformers", ".cache/hf_home", "__pycache__"] + }, + "pretrain": { + "script": "trainer/train_pretrain.py", + "save_prefix": "exp_pretrain", + "args": { + "epochs": 1, + "batch_size": 64, + "learning_rate": 0.0005, + "dtype": "float16", + "num_workers": 4, + "accumulation_steps": 4, + "grad_clip": 1.0, + "max_seq_len": 512, + 
"hidden_size": 512, + "num_hidden_layers": 8, + "use_moe": 0, + "use_compile": 0 + }, + "flags": [] + }, + "sft": { + "enabled": true, + "script": "trainer/train_full_sft.py", + "save_prefix": "exp_sft", + "from_weight_mode": "match_pretrain", + "from_weight": "pretrain", + "args": { + "epochs": 1, + "batch_size": 16, + "learning_rate": 1e-6, + "dtype": "float16", + "num_workers": 4, + "accumulation_steps": 2, + "grad_clip": 1.0, + "max_seq_len": 512, + "hidden_size": 512, + "num_hidden_layers": 8, + "use_moe": 0, + "use_compile": 0 + }, + "flags": [] + }, + "growth": { + "init_active_ratio": 0.8, + "grow_interval": 100, + "grow_ratio": 0.02, + "max_active_ratio": 0.99, + "neuron_ema_beta": 0.1 + }, + "methods": ["baseline", "random", "grad", "act", "actgrad"], + "seeds": [42, 123, 2026], + "eval": { + "targets": ["pretrain", "sft"], + "max_seq_len": 512, + "batch_size": 8, + "max_samples": 0 + } +} diff --git a/eval/prompts_minimal.jsonl b/eval/prompts_minimal.jsonl new file mode 100644 index 0000000..be7bfa7 --- /dev/null +++ b/eval/prompts_minimal.jsonl @@ -0,0 +1,20 @@ +{"prompt": "你是谁?请用一句话介绍自己。"} +{"prompt": "解释一下什么是梯度下降。"} +{"prompt": "写一个Python函数计算斐波那契数列。"} +{"prompt": "用通俗语言解释注意力机制(attention)。"} +{"prompt": "比较一下猫和狗作为宠物的优缺点。"} +{"prompt": "如果明天下雨,我应该如何出门?"} +{"prompt": "总结一下牛顿三大定律。"} +{"prompt": "给我一个三点清单,说明如何高效学习。"} +{"prompt": "用一句话解释什么是过拟合。"} +{"prompt": "请把‘大模型训练需要大量数据’翻译成英文。"} +{"prompt": "写一段关于春天的短诗(四行以内)。"} +{"prompt": "请解释为什么天空是蓝色的。"} +{"prompt": "简述 Transformer 的主要组成部分。"} +{"prompt": "写一个SQL查询示例:统计用户表中每个城市的人数。"} +{"prompt": "用三句话介绍机器学习的应用场景。"} +{"prompt": "给出一个常见的排序算法,并说明时间复杂度。"} +{"prompt": "什么是大语言模型的token?"} +{"prompt": "请用Python写一个判断素数的函数。"} +{"prompt": "解释一下什么是强化学习。"} +{"prompt": "用简短例子说明‘因果关系’与‘相关性’的区别。"} diff --git a/eval/results_template.csv b/eval/results_template.csv new file mode 100644 index 0000000..2d6b75a --- /dev/null +++ b/eval/results_template.csv @@ -0,0 +1 @@ +run_name,model_size,data_version,steps,learning_rate,batch_size,seq_len,growth_method,init_ratio,grow_interval,grow_ratio,max_ratio,seed,ppl,notes diff --git a/model/model_minimind.py b/model/model_minimind.py index b3910a8..312ba36 100755 --- a/model/model_minimind.py +++ b/model/model_minimind.py @@ -224,9 +224,39 @@ class FeedForward(nn.Module): self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) self.dropout = nn.Dropout(config.dropout) self.act_fn = ACT2FN[config.hidden_act] + # --- 动态神经元生长相关(默认关闭,不影响原行为) --- + # mask=1 表示该神经元激活;mask=0 表示该神经元屏蔽 + self.register_buffer("mask", torch.ones(config.intermediate_size), persistent=True) + # 记录神经元活动的 EMA(用于活动驱动的生长策略) + self.register_buffer("ema_act", torch.zeros(config.intermediate_size), persistent=True) + # 下面是控制开关(由训练脚本设置) + self.track_activity = False # 是否统计活动 + self.track_mask_grad = False # 是否追踪 mask 的梯度 + self.ema_beta = 0.1 # EMA 衰减系数(可被训练脚本覆盖) + self._mask_proxy = None # 临时存放可求梯度的 mask def forward(self, x): - return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))) + # 1) 计算 FFN 中间激活(未加 mask) + h = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + + # 2) 统计活动强度(可选) + if self.training and self.track_activity: + with torch.no_grad(): + # 对 batch 和序列维度做平均,得到每个神经元的活动强度 + act = h.abs().mean(dim=(0, 1)) + self.ema_act.mul_(1 - self.ema_beta).add_(self.ema_beta * act) + + # 3) 应用神经元 mask(可选追踪梯度) + mask = self.mask + if self.training and self.track_mask_grad: + # 注意:用 proxy 承接梯度,避免把 buffer 变成参数 + mask = self.mask.detach().clone().requires_grad_(True) + self._mask_proxy = mask + else: + 
 
 
 class MoEGate(nn.Module):
diff --git a/scripts/aggregate_eval.py b/scripts/aggregate_eval.py
new file mode 100644
index 0000000..e499df3
--- /dev/null
+++ b/scripts/aggregate_eval.py
@@ -0,0 +1,99 @@
+import os
+import json
+import argparse
+import csv
+
+
+def collect_json_files(path):
+    files = []
+    if os.path.isdir(path):
+        for name in sorted(os.listdir(path)):
+            if name.endswith(".json"):
+                files.append(os.path.join(path, name))
+    elif os.path.isfile(path):
+        files.append(path)
+    return files
+
+
+def load_json(path):
+    with open(path, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Aggregate PPL evaluation results into a CSV")
+    parser.add_argument("--inputs", required=True, type=str, help="JSON file or directory (eval_ppl output)")
+    parser.add_argument("--out", default="eval/summary_ppl.csv", type=str, help="output CSV path")
+    parser.add_argument("--group_by", default="", type=str, help="group rows by this field (optional)")
+    parser.add_argument("--out_grouped", default="", type=str, help="output path for the grouped summary CSV (optional)")
+    args = parser.parse_args()
+
+    files = collect_json_files(args.inputs)
+    if not files:
+        raise ValueError(f"No json files found in {args.inputs}")
+
+    rows = []
+    for path in files:
+        data = load_json(path)
+        data["file"] = os.path.basename(path)
+        rows.append(data)
+
+    out_dir = os.path.dirname(args.out)
+    if out_dir:  # guard: dirname is empty for bare filenames
+        os.makedirs(out_dir, exist_ok=True)
+    # Write the CSV
+    keys = sorted({k for r in rows for k in r.keys()})
+    with open(args.out, "w", newline="", encoding="utf-8") as f:
+        writer = csv.DictWriter(f, fieldnames=keys)
+        writer.writeheader()
+        for r in rows:
+            writer.writerow(r)
+
+    print(f"[OK] Saved {len(rows)} rows to {args.out}")
+
+    if args.group_by:
+        grouped = {}
+        for r in rows:
+            key = r.get(args.group_by)
+            if key is None:
+                continue
+            grouped.setdefault(key, []).append(r)
+
+        def to_float(x):
+            try:
+                return float(x)
+            except Exception:
+                return None
+
+        out_grouped = args.out_grouped or os.path.splitext(args.out)[0] + f"_grouped_{args.group_by}.csv"
+        # Collect the numeric columns
+        numeric_cols = set()
+        for r in rows:
+            for k, v in r.items():
+                if isinstance(v, (int, float)) or to_float(v) is not None:
+                    numeric_cols.add(k)
+
+        base_cols = sorted(numeric_cols)
+        std_cols = [f"{c}_std" for c in base_cols]
+        fieldnames = [args.group_by, "count"] + base_cols + std_cols
+        with open(out_grouped, "w", newline="", encoding="utf-8") as f:
+            writer = csv.DictWriter(f, fieldnames=fieldnames)
+            writer.writeheader()
+            for key, rs in grouped.items():
+                row = {args.group_by: key, "count": len(rs)}
+                for col in numeric_cols:
+                    vals = []
+                    for r in rs:
+                        v = r.get(col)
+                        v = to_float(v)
+                        if v is not None:
+                            vals.append(v)
+                    if vals:
+                        mean = sum(vals) / len(vals)
+                        var = sum((x - mean) ** 2 for x in vals) / len(vals)
+                        row[col] = round(mean, 6)
+                        row[f"{col}_std"] = round(var ** 0.5, 6)
+                writer.writerow(row)
+        print(f"[OK] Saved grouped summary to {out_grouped}")
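+
+# Note: the grouped std above is the population standard deviation (ddof=0).
+# Worked example for one method over 3 seeds with hypothetical ppl values
+# [18.3, 18.9, 18.0]: mean = 18.4, var = (0.1^2 + 0.5^2 + 0.4^2) / 3 = 0.14,
+# std ~= 0.374. With only 3 seeds, the sample std (ddof=1),
+# sqrt(0.42 / 2) ~= 0.458, is the more common choice to report in a paper.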
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/eval_fixed_prompts.py b/scripts/eval_fixed_prompts.py
new file mode 100644
index 0000000..311e7ea
--- /dev/null
+++ b/scripts/eval_fixed_prompts.py
@@ -0,0 +1,140 @@
+import os
+import json
+import time
+import argparse
+import torch
+from transformers import AutoTokenizer
+from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from trainer.trainer_utils import setup_seed, get_model_params
+
+
+def load_prompts(path):
+    prompts = []
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f:
+            if not line.strip():
+                continue
+            item = json.loads(line)
+            prompt = item.get("prompt") or item.get("text")
+            if prompt:
+                prompts.append(prompt)
+    return prompts
+
+
+def init_model(args):
+    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
+    if "model" in args.load_from:
+        model = MiniMindForCausalLM(MiniMindConfig(
+            hidden_size=args.hidden_size,
+            num_hidden_layers=args.num_hidden_layers,
+            use_moe=bool(args.use_moe),
+            inference_rope_scaling=args.inference_rope_scaling
+        ))
+        moe_suffix = '_moe' if args.use_moe else ''
+        ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
+        state = torch.load(ckp, map_location=args.device)
+        missing, unexpected = model.load_state_dict(state, strict=False)
+        if missing:
+            print(f"[WARN] Missing keys: {missing}")
+        if unexpected:
+            print(f"[WARN] Unexpected keys: {unexpected}")
+    else:
+        from transformers import AutoModelForCausalLM
+        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
+
+    get_model_params(model, model.config)
+    return model.eval().to(args.device), tokenizer
+
+
+def main():
+    parser = argparse.ArgumentParser(description="MiniMind fixed-prompt evaluation")
+    parser.add_argument('--load_from', default='model', type=str, help="model source ('model' = native torch weights)")
+    parser.add_argument('--save_dir', default='out', type=str, help="weights directory")
+    parser.add_argument('--weight', default='full_sft', type=str, help="weight name prefix")
+    parser.add_argument('--hidden_size', default=512, type=int, help="hidden size")
+    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
+    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use the MoE architecture")
+    parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="enable RoPE extrapolation")
+    parser.add_argument('--max_new_tokens', default=512, type=int, help="max new tokens to generate")
+    parser.add_argument('--temperature', default=0.7, type=float, help="sampling temperature")
+    parser.add_argument('--top_p', default=0.9, type=float, help="top-p sampling")
+    parser.add_argument('--do_sample', default=0, type=int, choices=[0, 1], help="sampling mode (0 = greedy, 1 = sample)")
+    parser.add_argument('--seed', default=2026, type=int, help="random seed")
+    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="device")
+    parser.add_argument('--prompts_file', default='eval/prompts_minimal.jsonl', type=str, help="prompt file path")
+    parser.add_argument('--out_dir', default='eval_runs', type=str, help="output directory")
+    parser.add_argument('--run_name', default='', type=str, help="name of this evaluation run")
+    parser.add_argument('--use_chat', default=-1, type=int, choices=[-1, 0, 1], help="use the chat template (-1 = auto)")
+    parser.add_argument('--config', default='', type=str, help="evaluation config file (JSON; overrides generation parameters)")
+    args = parser.parse_args()
+
+    # Optional: override evaluation parameters from the config file
+    if args.config:
+        with open(args.config, "r", encoding="utf-8") as f:
+            cfg = json.load(f)
+        for k, v in cfg.items():
+            if hasattr(args, k):
+                setattr(args, k, v)
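+
+    # Note: values from --config override CLI-provided values, not just the
+    # defaults. E.g. with eval/eval_config.json ("max_new_tokens": 256), passing
+    # --max_new_tokens 512 on the command line still results in 256.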
+
+    os.makedirs(args.out_dir, exist_ok=True)
+    prompts = load_prompts(args.prompts_file)
+    if not prompts:
+        raise ValueError(f"No prompts found in {args.prompts_file}")
+
+    if args.use_chat == -1:
+        use_chat = (args.weight != 'pretrain')
+    else:
+        use_chat = bool(args.use_chat)
+
+    setup_seed(args.seed)
+    model, tokenizer = init_model(args)
+
+    timestamp = time.strftime("%Y%m%d_%H%M%S")
+    run_name = args.run_name or f"{args.weight}_{timestamp}"
+    out_path = os.path.join(args.out_dir, f"{run_name}.jsonl")
+
+    print(f"[Eval] prompts: {len(prompts)} | use_chat={use_chat} | out={out_path}")
+
+    with open(out_path, "w", encoding="utf-8") as f:
+        for i, prompt in enumerate(prompts):
+            setup_seed(args.seed + i)
+            if use_chat:
+                messages = [{"role": "user", "content": prompt}]
+                text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            else:
+                text = (tokenizer.bos_token or "") + prompt  # guard against a missing bos_token
+
+            inputs = tokenizer(text, return_tensors="pt", truncation=True).to(args.device)
+            start = time.time()
+            with torch.no_grad():
+                gen_ids = model.generate(
+                    inputs=inputs["input_ids"],
+                    attention_mask=inputs.get("attention_mask"),
+                    max_new_tokens=args.max_new_tokens,
+                    do_sample=bool(args.do_sample),
+                    temperature=args.temperature,
+                    top_p=args.top_p,
+                    pad_token_id=tokenizer.pad_token_id,
+                    eos_token_id=tokenizer.eos_token_id,
+                    repetition_penalty=1.0
+                )
+            elapsed = time.time() - start
+            gen_tokens = gen_ids.shape[-1] - inputs["input_ids"].shape[-1]
+            response = tokenizer.decode(gen_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+
+            record = {
+                "id": i,
+                "prompt": prompt,
+                "response": response,
+                "gen_tokens": gen_tokens,
+                "time_sec": round(elapsed, 4),
+                "tokens_per_sec": round(gen_tokens / max(elapsed, 1e-6), 2)
+            }
+            f.write(json.dumps(record, ensure_ascii=False) + "\n")
+            print(f"[{i+1}/{len(prompts)}] tokens={gen_tokens} time={elapsed:.2f}s")
+
+    print(f"[Done] saved to {out_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/eval_ppl.py b/scripts/eval_ppl.py
new file mode 100644
index 0000000..1b2a315
--- /dev/null
+++ b/scripts/eval_ppl.py
@@ -0,0 +1,119 @@
+import os
+import math
+import time
+import argparse
+import json
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer
+from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from dataset.lm_dataset import PretrainDataset
+from trainer.trainer_utils import get_model_params
+
+
+def init_model(args):
+    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
+    if "model" in args.load_from:
+        model = MiniMindForCausalLM(MiniMindConfig(
+            hidden_size=args.hidden_size,
+            num_hidden_layers=args.num_hidden_layers,
+            use_moe=bool(args.use_moe),
+            inference_rope_scaling=args.inference_rope_scaling
+        ))
+        moe_suffix = '_moe' if args.use_moe else ''
+        ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
+        state = torch.load(ckp, map_location=args.device)
+        model.load_state_dict(state, strict=False)
+    else:
+        from transformers import AutoModelForCausalLM
+        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
+    get_model_params(model, model.config)
+    return model.eval().to(args.device), tokenizer
+
+
+def main():
+    parser = argparse.ArgumentParser(description="MiniMind PPL evaluation")
+    parser.add_argument('--load_from', default='model', type=str, help="model source ('model' = native torch weights)")
+    parser.add_argument('--save_dir', default='out', type=str, help="weights directory")
+    parser.add_argument('--weight', default='pretrain', type=str, help="weight name prefix")
+    parser.add_argument('--hidden_size', default=512, type=int, help="hidden size")
+    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers")
+    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use the MoE architecture")
+    parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="enable RoPE extrapolation")
+    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="device")
+    parser.add_argument('--data_path', default='dataset/pretrain_hq.jsonl', type=str, help="evaluation data path")
+    parser.add_argument('--max_seq_len', default=340, type=int, help="max sequence length")
+    parser.add_argument('--batch_size', default=8, type=int, help="batch size")
+    parser.add_argument('--num_workers', default=4, type=int, help="dataloader workers")
+    parser.add_argument('--max_samples', default=0, type=int, help="max number of samples (0 = all)")
+    parser.add_argument('--out_path', default='', type=str, help="JSON file to save the results to")
+    parser.add_argument('--method', default='', type=str, help="method name (for aggregation)")
+    args = parser.parse_args()
+
+    model, tokenizer = init_model(args)
+    ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
+    if args.max_samples and args.max_samples > 0:
+        ds.samples = ds.samples.select(range(min(args.max_samples, len(ds.samples))))
+
+    loader = DataLoader(ds, batch_size=args.batch_size, num_workers=args.num_workers)
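+
+    # Cross-entropy is summed per token (reduction='sum') and normalized once
+    # at the end, so ppl = exp(total_nll / total_tokens): a token-weighted
+    # perplexity rather than a mean of per-batch means, which would bias
+    # the result toward short batches.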
+
+    total_loss = 0.0
+    total_tokens = 0
+    total_steps = 0
+    start = time.time()
+
+    with torch.no_grad():
+        for input_ids, labels in loader:
+            input_ids = input_ids.to(args.device)
+            labels = labels.to(args.device)
+
+            outputs = model(input_ids)
+            logits = outputs.logits
+            shift_logits = logits[:, :-1, :].contiguous()
+            shift_labels = labels[:, 1:].contiguous()
+
+            loss = F.cross_entropy(
+                shift_logits.view(-1, shift_logits.size(-1)),
+                shift_labels.view(-1),
+                ignore_index=-100,
+                reduction='sum'
+            )
+            token_count = (shift_labels != -100).sum().item()
+            total_loss += loss.item()
+            total_tokens += token_count
+            total_steps += 1
+
+    avg_loss = total_loss / max(total_tokens, 1)
+    ppl = math.exp(avg_loss)
+    elapsed = time.time() - start
+    tokens_per_sec = total_tokens / max(elapsed, 1e-6)
+
+    print(f"[PPL] loss={avg_loss:.4f} ppl={ppl:.4f}")
+    print(f"[Info] tokens={total_tokens} steps={total_steps} time={elapsed:.2f}s tokens/s={tokens_per_sec:.2f}")
+
+    if args.out_path:
+        out_dir = os.path.dirname(args.out_path)
+        if out_dir:  # guard: dirname is empty for bare filenames
+            os.makedirs(out_dir, exist_ok=True)
+        result = {
+            "loss": avg_loss,
+            "ppl": ppl,
+            "tokens": total_tokens,
+            "steps": total_steps,
+            "time_sec": elapsed,
+            "tokens_per_sec": tokens_per_sec,
+            "data_path": args.data_path,
+            "weight": args.weight,
+            "hidden_size": args.hidden_size,
+            "num_hidden_layers": args.num_hidden_layers,
+            "use_moe": args.use_moe,
+            "max_seq_len": args.max_seq_len,
+            "batch_size": args.batch_size
+        }
+        if args.method:
+            result["method"] = args.method
+        with open(args.out_path, "w", encoding="utf-8") as f:
+            json.dump(result, f, ensure_ascii=False, indent=2)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/make_val_split.py b/scripts/make_val_split.py
new file mode 100644
index 0000000..694ff10
--- /dev/null
+++ b/scripts/make_val_split.py
@@ -0,0 +1,63 @@
+import os
+import json
+import argparse
+import random
+
+
+def count_lines(path):
+    with open(path, "r", encoding="utf-8") as f:
+        return sum(1 for _ in f)
+
+
+def reservoir_sample(path, k, seed=42):
+    random.seed(seed)
+    reservoir = []
+    with open(path, "r", encoding="utf-8") as f:
+        for i, line in enumerate(f):
+            if i < k:
+                reservoir.append(line)
+            else:
+                j = random.randint(0, i)
+                if j < k:
+                    reservoir[j] = line
+    return reservoir
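+
+# Why reservoir sampling is uniform: after n >= k lines, every line has been
+# kept with probability exactly k/n. Line i (0-indexed, i >= k) enters the
+# reservoir with probability k/(i+1) and survives the step for a later line j
+# with probability 1 - (k/(j+1)) * (1/k) = j/(j+1); the product telescopes
+# to k/n, so one pass over the file yields a uniform sample.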
parser.add_argument("--val_ratio", default=0.0, type=float, help="验证集比例(若 val_size=0 则使用)") + parser.add_argument("--seed", default=42, type=int, help="随机种子") + args = parser.parse_args() + + if not os.path.exists(args.data_path): + raise FileNotFoundError(f"data_path not found: {args.data_path}") + + if args.val_size <= 0: + if args.val_ratio <= 0: + raise ValueError("val_size <=0 时必须提供 val_ratio > 0") + total = count_lines(args.data_path) + args.val_size = max(1, int(total * args.val_ratio)) + + os.makedirs(os.path.dirname(args.out_path), exist_ok=True) + samples = reservoir_sample(args.data_path, args.val_size, seed=args.seed) + + with open(args.out_path, "w", encoding="utf-8") as f: + for line in samples: + line = line.strip() + if not line: + continue + # 简单校验 JSONL 格式 + try: + _ = json.loads(line) + except Exception: + continue + f.write(line + "\n") + + print(f"[OK] Saved {len(samples)} lines to {args.out_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_growth.py b/scripts/plot_growth.py new file mode 100644 index 0000000..3b3a0b8 --- /dev/null +++ b/scripts/plot_growth.py @@ -0,0 +1,47 @@ +import os +import csv +import argparse +import matplotlib.pyplot as plt + + +def main(): + parser = argparse.ArgumentParser(description="绘制动态生长对照图") + parser.add_argument("--csv", default="eval/summary_ppl.csv", type=str, help="CSV 路径") + parser.add_argument("--x", default="weight", type=str, help="x 轴列名") + parser.add_argument("--y", default="ppl", type=str, help="y 轴列名") + parser.add_argument("--out", default="eval/plot_ppl.png", type=str, help="输出图片") + args = parser.parse_args() + + if not os.path.exists(args.csv): + raise FileNotFoundError(args.csv) + + x_vals = [] + y_vals = [] + with open(args.csv, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + if args.x not in row or args.y not in row: + continue + x_vals.append(str(row[args.x])) + try: + y_vals.append(float(row[args.y])) + except Exception: + y_vals.append(float("nan")) + + if not x_vals: + raise ValueError("No valid rows found in CSV") + + plt.figure(figsize=(8, 4)) + plt.bar(x_vals, y_vals) + plt.xlabel(args.x) + plt.ylabel(args.y) + plt.xticks(rotation=45, ha='right') + plt.tight_layout() + + os.makedirs(os.path.dirname(args.out), exist_ok=True) + plt.savefig(args.out, dpi=200) + print(f"[OK] Saved plot to {args.out}") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_growth_sweep.py b/scripts/run_growth_sweep.py new file mode 100644 index 0000000..94b0da6 --- /dev/null +++ b/scripts/run_growth_sweep.py @@ -0,0 +1,41 @@ +import argparse +import shlex +import subprocess + + +def build_commands(script, prefix, base_args, seeds): + runs = [ + ("baseline", "--neuron_growth 0"), + ("random", "--neuron_growth 1 --grow_method random"), + ("grad", "--neuron_growth 1 --grow_method act_grad --grow_score_alpha 0.0 --grow_score_beta 1.0"), + ("act", "--neuron_growth 1 --grow_method act_grad --grow_score_alpha 1.0 --grow_score_beta 0.0"), + ("actgrad", "--neuron_growth 1 --grow_method act_grad --grow_score_alpha 1.0 --grow_score_beta 1.0"), + ] + cmds = [] + for seed in seeds: + for name, extra in runs: + save_weight = f"{prefix}_{name}_s{seed}" + cmd = f"python {script} --save_weight {save_weight} --seed {seed} {base_args} {extra}".strip() + cmds.append(cmd) + return cmds + + +def main(): + parser = argparse.ArgumentParser(description="批量运行动态神经元生长对照实验") + parser.add_argument("--script", default="trainer/train_pretrain.py", type=str, help="训练脚本路径") + 
parser.add_argument("--prefix", default="exp", type=str, help="save_weight 前缀") + parser.add_argument("--base_args", default="", type=str, help="统一附加参数") + parser.add_argument("--seeds", default="42", type=str, help="随机种子列表(逗号分隔)") + parser.add_argument("--run", action="store_true", help="实际执行(不加则只打印命令)") + args = parser.parse_args() + + seeds = [int(s) for s in args.seeds.split(",") if s.strip()] + cmds = build_commands(args.script, args.prefix, args.base_args, seeds) + for c in cmds: + print(c) + if args.run: + subprocess.run(shlex.split(c), check=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_overnight_pipeline.py b/scripts/run_overnight_pipeline.py new file mode 100644 index 0000000..dab4c23 --- /dev/null +++ b/scripts/run_overnight_pipeline.py @@ -0,0 +1,518 @@ +import os +import sys +import json +import time +import shlex +import shutil +import argparse +import subprocess + +ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +PYTHON = sys.executable + + +def load_config(path): + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +def resolve_path(path): + if not path: + return path + return path if os.path.isabs(path) else os.path.join(ROOT, path) + + +def ensure_dir(path): + if not path: + return + os.makedirs(path, exist_ok=True) + + +def now_ts(): + return time.strftime("%Y%m%d_%H%M%S") + + +def dict_to_args(d, flags=None): + flags = set(flags or []) + args = [] + for k, v in d.items(): + key = f"--{k}" + if isinstance(v, bool): + if k in flags: + if v: + args.append(key) + else: + args.extend([key, str(int(v))]) + elif v is None: + continue + else: + args.extend([key, str(v)]) + return args + + +def build_method_args(method, growth_cfg): + if method == "baseline": + return {"neuron_growth": 0} + if method == "random": + return {"neuron_growth": 1, "grow_method": "random", **growth_cfg} + if method == "grad": + return { + "neuron_growth": 1, + "grow_method": "act_grad", + "grow_score_alpha": 0.0, + "grow_score_beta": 1.0, + **growth_cfg + } + if method == "act": + return { + "neuron_growth": 1, + "grow_method": "act_grad", + "grow_score_alpha": 1.0, + "grow_score_beta": 0.0, + **growth_cfg + } + if method == "actgrad": + return { + "neuron_growth": 1, + "grow_method": "act_grad", + "grow_score_alpha": 1.0, + "grow_score_beta": 1.0, + **growth_cfg + } + raise ValueError(f"Unknown method: {method}") + + +def run_cmd(cmd, log_path, env=None, cwd=None, retries=0, dry_run=False): + cmd_str = " ".join(shlex.quote(c) for c in cmd) + print("[CMD]", cmd_str) + if dry_run: + return True + for attempt in range(retries + 1): + stamp = time.strftime("%Y-%m-%d %H:%M:%S") + with open(log_path, "a", encoding="utf-8") as log: + log.write(f"\n[{stamp}] $ {cmd_str}\n") + proc = subprocess.run(cmd, cwd=cwd or ROOT, env=env, stdout=log, stderr=log, text=True) + if proc.returncode == 0: + return True + if attempt < retries: + time.sleep(5) + return False + + +def run_capture(cmd): + try: + proc = subprocess.run(cmd, cwd=ROOT, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) + return proc.returncode == 0, proc.stdout.strip() + except FileNotFoundError: + return False, "" + except Exception as e: + return False, str(e) + + +def disk_free_gb(path): + usage = shutil.disk_usage(path) + return usage.free / (1024 ** 3) + + +def safe_remove(path): + abs_path = resolve_path(path) + if not abs_path: + return + if not abs_path.startswith(ROOT + os.sep): + print(f"[WARN] Skip remove outside repo: {abs_path}") + return + if 
os.path.isdir(abs_path): + shutil.rmtree(abs_path, ignore_errors=True) + elif os.path.isfile(abs_path): + try: + os.remove(abs_path) + except OSError: + pass + + +def cleanup_paths(paths): + for p in paths: + if p == "__pycache__": + for root, dirs, _ in os.walk(ROOT): + for d in dirs: + if d == "__pycache__": + safe_remove(os.path.join(root, d)) + else: + safe_remove(p) + + +def is_lfs_pointer(path): + try: + with open(path, "r", encoding="utf-8") as f: + head = f.readline().strip() + return head.startswith("version https://git-lfs.github.com/spec/v1") + except Exception: + return False + + +def preflight_checks(runtime, log_path): + checks = [] + ok_git, out_git = run_capture(["git", "--version"]) + checks.append(("git", ok_git, out_git or "not found")) + ok_lfs, out_lfs = run_capture(["git", "lfs", "version"]) + checks.append(("git_lfs", ok_lfs, out_lfs or "not found")) + + ok_smi, out_smi = run_capture([ + "nvidia-smi", + "--query-gpu=name,memory.total,memory.free,utilization.gpu", + "--format=csv,noheader,nounits" + ]) + checks.append(("nvidia_smi", ok_smi, out_smi or "not found")) + + min_free = runtime.get("min_free_gb", 0) + free_gb = disk_free_gb(ROOT) + disk_ok = free_gb >= min_free if min_free else True + checks.append(("disk_free_gb", disk_ok, f"{free_gb:.1f}GB (min {min_free}GB)")) + + # 写入日志 + stamp = time.strftime("%Y-%m-%d %H:%M:%S") + with open(log_path, "a", encoding="utf-8") as log: + log.write(f"\n[{stamp}] [PREFLIGHT]\n") + for name, ok, info in checks: + log.write(f"{name}: {'OK' if ok else 'FAIL'} | {info}\n") + return checks + + +def ensure_dataset_repo(dataset_cfg, log_path, env, dry_run, retries=0): + target = resolve_path(dataset_cfg.get("target_dir", "minimind_dataset")) + if os.path.isdir(target): + return True + repo = dataset_cfg.get("repo", "") + if not repo: + print("[WARN] dataset.repo not set") + return False + ok = run_cmd(["git", "lfs", "install"], log_path, env=env, dry_run=dry_run, retries=retries) + ok = run_cmd(["git", "clone", repo, target], log_path, env=env, dry_run=dry_run, retries=retries) and ok + return ok + + +def lfs_pull(dataset_cfg, files, log_path, env, dry_run, retries=0): + target = resolve_path(dataset_cfg.get("target_dir", "minimind_dataset")) + if not os.path.isdir(target): + return False + mode = dataset_cfg.get("download_mode", "selective") + if mode == "full": + return run_cmd(["git", "lfs", "pull"], log_path, env=env, cwd=target, dry_run=dry_run, retries=retries) + files = [f for f in files if f] + if not files: + return True + include = ",".join(files) + return run_cmd(["git", "lfs", "pull", "--include", include], log_path, env=env, cwd=target, dry_run=dry_run, retries=retries) + + +def ensure_dataset_files(dataset_cfg, stage, log_path, env, dry_run, retries=0): + if not ensure_dataset_repo(dataset_cfg, log_path, env, dry_run, retries=retries): + return False + stage_files = dataset_cfg.get("stage_files", {}).get(stage, []) + if not stage_files: + return True + # 缺失就尝试拉取 + missing = [] + for name in stage_files: + file_path = resolve_path(os.path.join(dataset_cfg.get("target_dir", "minimind_dataset"), name)) + if not os.path.exists(file_path) or is_lfs_pointer(file_path): + missing.append(name) + if not missing: + return True + ok = lfs_pull(dataset_cfg, missing, log_path, env, dry_run, retries=retries) + if not ok: + return False + # 再次检查 + for name in missing: + file_path = resolve_path(os.path.join(dataset_cfg.get("target_dir", "minimind_dataset"), name)) + if not os.path.exists(file_path) or is_lfs_pointer(file_path): + 
return False + return True + + +def weight_path(prefix, hidden_size, use_moe, save_dir): + moe_suffix = "_moe" if use_moe else "" + return os.path.join(resolve_path(save_dir), f"{prefix}_{hidden_size}{moe_suffix}.pth") + + +def main(): + parser = argparse.ArgumentParser(description="MiniMind 过夜一键训练/评测/汇总/绘图脚本(强纠错)") + parser.add_argument("--config", default="eval/pipeline_config_overnight.json", type=str, help="配置文件路径") + parser.add_argument("--stages", default="", type=str, help="仅运行的阶段(逗号分隔)") + parser.add_argument("--methods", default="", type=str, help="方法列表(逗号分隔)") + parser.add_argument("--seeds", default="", type=str, help="随机种子列表(逗号分隔)") + parser.add_argument("--dry_run", action="store_true", help="只打印命令,不执行") + parser.add_argument("--stop_on_error", action="store_true", help="遇到错误就停止") + args = parser.parse_args() + + cfg = load_config(resolve_path(args.config)) + runtime = cfg.get("runtime", {}) + retry_download = int(runtime.get("retry_download", 0)) + dataset_cfg = cfg.get("dataset", {}) + paths = cfg.get("paths", {}) + growth_cfg = cfg.get("growth", {}) + cache_cfg = cfg.get("cache", {}) + + stages = cfg.get("stages", []) + if args.stages: + stages = [s.strip() for s in args.stages.split(",") if s.strip()] + + methods = cfg.get("methods", ["baseline"]) + if args.methods: + methods = [m.strip() for m in args.methods.split(",") if m.strip()] + + seeds = cfg.get("seeds", [42]) + if args.seeds: + seeds = [int(s) for s in args.seeds.split(",") if s.strip()] + + log_dir = resolve_path(paths.get("log_dir", "logs")) + ensure_dir(log_dir) + log_path = os.path.join(log_dir, f"pipeline_{now_ts()}.log") + + # 统一缓存目录,方便清理 + env = os.environ.copy() + hf_home = resolve_path(cache_cfg.get("hf_home", ".cache/hf_home")) + hf_datasets = resolve_path(cache_cfg.get("hf_datasets", ".cache/hf_datasets")) + hf_transformers = resolve_path(cache_cfg.get("transformers", ".cache/hf_transformers")) + env["HF_HOME"] = hf_home + env["HF_DATASETS_CACHE"] = hf_datasets + env["TRANSFORMERS_CACHE"] = hf_transformers + env["HF_HUB_DISABLE_TELEMETRY"] = "1" + env["TOKENIZERS_PARALLELISM"] = "false" + + failures = [] + + def mark_fail(stage, msg): + failures.append(f"{stage}: {msg}") + print(f"[FAIL] {stage}: {msg}") + if args.stop_on_error or not runtime.get("continue_on_error", True): + raise RuntimeError(msg) + + def maybe_cleanup(stage_name): + cleanup_after = runtime.get("cleanup_after_stage", []) + if stage_name in cleanup_after: + cleanup_paths(runtime.get("cleanup_paths", [])) + + def disk_guard(stage_name): + min_free = runtime.get("min_free_gb", 0) + if not min_free: + return True + free_gb = disk_free_gb(ROOT) + if free_gb >= min_free: + return True + print(f"[WARN] Free disk {free_gb:.1f}GB < {min_free}GB, try cleanup") + cleanup_paths(runtime.get("cleanup_paths", [])) + free_gb = disk_free_gb(ROOT) + if free_gb < min_free: + mark_fail(stage_name, f"disk too low: {free_gb:.1f}GB") + return False + return True + + # ========== Stage: preflight ========== + if "preflight" in stages: + checks = preflight_checks(runtime, log_path) + for name, ok, info in checks: + status = "OK" if ok else "FAIL" + print(f"[PREFLIGHT] {name}: {status} | {info}") + if not ok: + mark_fail("preflight", f"{name} check failed") + maybe_cleanup("preflight") + + # ========== Stage: prepare ========== + if "prepare" in stages: + ensure_dir(resolve_path(paths.get("out_dir", "out"))) + ensure_dir(resolve_path(paths.get("eval_runs", "eval_runs"))) + ensure_dir(resolve_path(paths.get("eval_runs_ppl", "eval_runs_ppl"))) + 
ensure_dir(hf_home) + ensure_dir(hf_datasets) + ensure_dir(hf_transformers) + maybe_cleanup("prepare") + + # ========== Stage: download ========== + if "download" in stages: + if disk_guard("download"): + if not ensure_dataset_repo(dataset_cfg, log_path, env, args.dry_run, retries=retry_download): + mark_fail("download", "dataset repo init failed") + else: + stage_files = dataset_cfg.get("stage_files", {}) + union_files = [] + for s in stages: + union_files.extend(stage_files.get(s, [])) + ok = lfs_pull(dataset_cfg, sorted(set(union_files)), log_path, env, args.dry_run, retries=retry_download) + if not ok: + mark_fail("download", "git lfs pull failed") + maybe_cleanup("download") + + # ========== Stage: make_val ========== + if "make_val" in stages: + if disk_guard("make_val"): + if not ensure_dataset_files(dataset_cfg, "make_val", log_path, env, args.dry_run, retries=retry_download): + mark_fail("make_val", "missing pretrain data") + else: + cmd = [ + PYTHON, "scripts/make_val_split.py", + "--data_path", resolve_path(paths.get("pretrain_data", "")), + "--out_path", resolve_path(paths.get("val_data", "eval/val_pretrain.jsonl")), + "--val_size", "2000", + "--seed", "42" + ] + if not run_cmd(cmd, log_path, env=env, dry_run=args.dry_run): + mark_fail("make_val", "script failed") + maybe_cleanup("make_val") + + # ========== Stage: train_pretrain ========== + if "train_pretrain" in stages: + if disk_guard("train_pretrain"): + if not ensure_dataset_files(dataset_cfg, "train_pretrain", log_path, env, args.dry_run, retries=retry_download): + mark_fail("train_pretrain", "missing pretrain data") + else: + pre_cfg = cfg.get("pretrain", {}) + for seed in seeds: + for method in methods: + save_weight = f"{pre_cfg.get('save_prefix','pretrain')}_{method}_s{seed}" + base_args = dict_to_args(pre_cfg.get("args", {}), pre_cfg.get("flags", [])) + method_args = dict_to_args(build_method_args(method, growth_cfg)) + cmd = [ + PYTHON, pre_cfg.get("script", "trainer/train_pretrain.py"), + "--save_dir", resolve_path(paths.get("out_dir", "out")), + "--save_weight", save_weight, + "--data_path", resolve_path(paths.get("pretrain_data", "")), + "--seed", str(seed) + ] + base_args + method_args + if not run_cmd(cmd, log_path, env=env, retries=0, dry_run=args.dry_run): + mark_fail("train_pretrain", f"{save_weight} failed") + maybe_cleanup("train_pretrain") + + # ========== Stage: train_sft ========== + if "train_sft" in stages and cfg.get("sft", {}).get("enabled", False): + if disk_guard("train_sft"): + if not ensure_dataset_files(dataset_cfg, "train_sft", log_path, env, args.dry_run, retries=retry_download): + mark_fail("train_sft", "missing sft data") + else: + sft_cfg = cfg.get("sft", {}) + for seed in seeds: + for method in methods: + save_weight = f"{sft_cfg.get('save_prefix','sft')}_{method}_s{seed}" + base_args = dict_to_args(sft_cfg.get("args", {}), sft_cfg.get("flags", [])) + method_args = dict_to_args(build_method_args(method, growth_cfg)) + from_weight_mode = sft_cfg.get("from_weight_mode", "fixed") + if from_weight_mode == "match_pretrain": + from_weight = f"{cfg.get('pretrain', {}).get('save_prefix','pretrain')}_{method}_s{seed}" + else: + from_weight = sft_cfg.get("from_weight", "pretrain") + cmd = [ + PYTHON, sft_cfg.get("script", "trainer/train_full_sft.py"), + "--save_dir", resolve_path(paths.get("out_dir", "out")), + "--save_weight", save_weight, + "--data_path", resolve_path(paths.get("sft_data", "")), + "--from_weight", from_weight, + "--seed", str(seed) + ] + base_args + method_args + if not 
run_cmd(cmd, log_path, env=env, retries=0, dry_run=args.dry_run): + mark_fail("train_sft", f"{save_weight} failed") + maybe_cleanup("train_sft") + + # ========== Stage: eval_ppl ========== + if "eval_ppl" in stages: + if disk_guard("eval_ppl"): + if not ensure_dataset_files(dataset_cfg, "eval_ppl", log_path, env, args.dry_run, retries=retry_download): + mark_fail("eval_ppl", "missing val/pretrain data") + else: + eval_cfg = cfg.get("eval", {}) + targets = eval_cfg.get("targets") or [eval_cfg.get("target", "pretrain")] + for target in targets: + target_cfg = cfg.get(target, {}) + prefix = target_cfg.get("save_prefix", target) + use_moe = int(target_cfg.get("args", {}).get("use_moe", 0)) + hidden_size = int(target_cfg.get("args", {}).get("hidden_size", 512)) + for seed in seeds: + for method in methods: + weight = f"{prefix}_{method}_s{seed}" + if not os.path.exists(weight_path(weight, hidden_size, use_moe, paths.get("out_dir", "out"))): + mark_fail("eval_ppl", f"missing weight {weight}") + continue + out_dir = resolve_path(paths.get("eval_runs_ppl", "eval_runs_ppl")) + ensure_dir(out_dir) + out_path = os.path.join(out_dir, f"{weight}.json") + cmd = [ + PYTHON, "scripts/eval_ppl.py", + "--weight", weight, + "--save_dir", resolve_path(paths.get("out_dir", "out")), + "--data_path", resolve_path(paths.get("val_data", "")), + "--hidden_size", str(hidden_size), + "--max_seq_len", str(eval_cfg.get("max_seq_len", 340)), + "--batch_size", str(eval_cfg.get("batch_size", 8)), + "--max_samples", str(eval_cfg.get("max_samples", 0)), + "--method", method, + "--out_path", out_path + ] + if not run_cmd(cmd, log_path, env=env, dry_run=args.dry_run): + mark_fail("eval_ppl", f"{weight} failed") + maybe_cleanup("eval_ppl") + + # ========== Stage: eval_prompts ========== + if "eval_prompts" in stages: + if disk_guard("eval_prompts"): + prompts = resolve_path(paths.get("prompts", "")) + if not os.path.exists(prompts): + mark_fail("eval_prompts", "missing prompts file") + else: + eval_cfg = cfg.get("eval", {}) + targets = eval_cfg.get("targets") or [eval_cfg.get("target", "pretrain")] + for target in targets: + target_cfg = cfg.get(target, {}) + prefix = target_cfg.get("save_prefix", target) + use_moe = int(target_cfg.get("args", {}).get("use_moe", 0)) + hidden_size = int(target_cfg.get("args", {}).get("hidden_size", 512)) + for seed in seeds: + for method in methods: + weight = f"{prefix}_{method}_s{seed}" + if not os.path.exists(weight_path(weight, hidden_size, use_moe, paths.get("out_dir", "out"))): + mark_fail("eval_prompts", f"missing weight {weight}") + continue + cmd = [ + PYTHON, "scripts/eval_fixed_prompts.py", + "--weight", weight, + "--save_dir", resolve_path(paths.get("out_dir", "out")), + "--hidden_size", str(hidden_size), + "--prompts_file", prompts, + "--out_dir", resolve_path(paths.get("eval_runs", "eval_runs")), + "--config", resolve_path(paths.get("eval_config", "eval/eval_config.json")), + "--run_name", weight + ] + if not run_cmd(cmd, log_path, env=env, dry_run=args.dry_run): + mark_fail("eval_prompts", f"{weight} failed") + maybe_cleanup("eval_prompts") + + # ========== Stage: aggregate ========== + if "aggregate" in stages: + if disk_guard("aggregate"): + summary_csv = resolve_path(paths.get("summary_csv", "eval/summary_ppl.csv")) + cmd = [PYTHON, "scripts/aggregate_eval.py", "--inputs", resolve_path(paths.get("eval_runs_ppl", "eval_runs_ppl")), "--out", summary_csv] + if not run_cmd(cmd, log_path, env=env, dry_run=args.dry_run): + mark_fail("aggregate", "aggregate failed") + 
maybe_cleanup("aggregate") + + # ========== Stage: plot ========== + if "plot" in stages: + if disk_guard("plot"): + summary_csv = resolve_path(paths.get("summary_csv", "eval/summary_ppl.csv")) + plot_png = resolve_path(paths.get("plot_png", "eval/plot_ppl.png")) + cmd = [PYTHON, "scripts/plot_growth.py", "--csv", summary_csv, "--x", "weight", "--y", "ppl", "--out", plot_png] + if not run_cmd(cmd, log_path, env=env, dry_run=args.dry_run): + mark_fail("plot", "plot failed") + maybe_cleanup("plot") + + # ========== Stage: cleanup ========== + if "cleanup" in stages: + cleanup_paths(runtime.get("cleanup_paths", [])) + + summary = "\n".join(failures) if failures else "OK" + with open(log_path, "a", encoding="utf-8") as f: + f.write("\n[SUMMARY]\n") + f.write(summary + "\n") + print("[SUMMARY]", summary) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_paper_pipeline.py b/scripts/run_paper_pipeline.py new file mode 100644 index 0000000..4cb93c7 --- /dev/null +++ b/scripts/run_paper_pipeline.py @@ -0,0 +1,231 @@ +import os +import sys +import json +import shlex +import argparse +import subprocess + +ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +PYTHON = sys.executable + + +def load_config(path): + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +def dict_to_args(d, flags=None): + flags = set(flags or []) + args = [] + for k, v in d.items(): + key = f"--{k}" + if isinstance(v, bool): + if k in flags: + if v: + args.append(key) + else: + args.extend([key, str(int(v))]) + elif v is None: + continue + else: + args.extend([key, str(v)]) + return args + + +def run_cmd(cmd, run=False): + print("[CMD]", " ".join(cmd)) + if run: + subprocess.run(cmd, check=True, cwd=ROOT) + + +def ensure_exists(path, strict=False): + if os.path.exists(os.path.join(ROOT, path)): + return True + msg = f"[WARN] Missing path: {path}" + if strict: + raise FileNotFoundError(msg) + print(msg) + return False + + +def build_method_args(method, growth_cfg): + if method == "baseline": + return {"neuron_growth": 0} + if method == "random": + return { + "neuron_growth": 1, + "grow_method": "random", + **growth_cfg + } + if method == "grad": + return { + "neuron_growth": 1, + "grow_method": "act_grad", + "grow_score_alpha": 0.0, + "grow_score_beta": 1.0, + **growth_cfg + } + if method == "act": + return { + "neuron_growth": 1, + "grow_method": "act_grad", + "grow_score_alpha": 1.0, + "grow_score_beta": 0.0, + **growth_cfg + } + if method == "actgrad": + return { + "neuron_growth": 1, + "grow_method": "act_grad", + "grow_score_alpha": 1.0, + "grow_score_beta": 1.0, + **growth_cfg + } + raise ValueError(f"Unknown method: {method}") + + +def main(): + parser = argparse.ArgumentParser(description="MiniMind 论文实验一键流水线") + parser.add_argument("--config", default="eval/pipeline_config.json", type=str, help="配置文件路径") + parser.add_argument("--stages", default="", type=str, help="仅运行的阶段(逗号分隔)") + parser.add_argument("--methods", default="", type=str, help="方法列表(逗号分隔)") + parser.add_argument("--seeds", default="", type=str, help="随机种子列表(逗号分隔)") + parser.add_argument("--run", action="store_true", help="实际执行(不加则只打印命令)") + parser.add_argument("--strict", action="store_true", help="缺少数据路径时直接报错") + args = parser.parse_args() + + cfg = load_config(os.path.join(ROOT, args.config)) + + stages = cfg.get("stages", []) + if args.stages: + stages = [s.strip() for s in args.stages.split(",") if s.strip()] + + methods = cfg.get("methods", ["baseline"]) + if args.methods: + methods = [m.strip() 
for m in args.methods.split(",") if m.strip()] + + seeds = cfg.get("seeds", [42]) + if args.seeds: + seeds = [int(s) for s in args.seeds.split(",") if s.strip()] + + paths = cfg.get("paths", {}) + growth_cfg = cfg.get("growth", {}) + + # ========== Stage: make_val ========== + if "make_val" in stages: + if ensure_exists(paths.get("pretrain_data", ""), strict=args.strict): + cmd = [PYTHON, "scripts/make_val_split.py", + "--data_path", paths["pretrain_data"], + "--out_path", paths["val_data"], + "--val_size", "2000", + "--seed", "42"] + run_cmd(cmd, run=args.run) + else: + print("[SKIP] make_val (missing pretrain_data)") + + # ========== Stage: train_pretrain ========== + if "train_pretrain" in stages: + if ensure_exists(paths.get("pretrain_data", ""), strict=args.strict): + pre_cfg = cfg.get("pretrain", {}) + for seed in seeds: + for method in methods: + save_weight = f"{pre_cfg.get('save_prefix','pretrain')}_{method}_s{seed}" + base_args = dict_to_args(pre_cfg.get("args", {}), pre_cfg.get("flags", [])) + method_args = dict_to_args(build_method_args(method, growth_cfg)) + cmd = [PYTHON, pre_cfg.get("script", "trainer/train_pretrain.py"), + "--save_dir", paths.get("out_dir", "out"), + "--save_weight", save_weight, + "--data_path", paths.get("pretrain_data"), + "--seed", str(seed)] + base_args + method_args + run_cmd(cmd, run=args.run) + else: + print("[SKIP] train_pretrain (missing pretrain_data)") + + # ========== Stage: train_sft ========== + if "train_sft" in stages and cfg.get("sft", {}).get("enabled", False): + if ensure_exists(paths.get("sft_data", ""), strict=args.strict): + sft_cfg = cfg.get("sft", {}) + for seed in seeds: + for method in methods: + save_weight = f"{sft_cfg.get('save_prefix','sft')}_{method}_s{seed}" + base_args = dict_to_args(sft_cfg.get("args", {}), sft_cfg.get("flags", [])) + method_args = dict_to_args(build_method_args(method, growth_cfg)) + # from_weight 选择 + from_weight_mode = sft_cfg.get("from_weight_mode", "fixed") + if from_weight_mode == "match_pretrain": + from_weight = f"{cfg.get('pretrain', {}).get('save_prefix','pretrain')}_{method}_s{seed}" + else: + from_weight = sft_cfg.get("from_weight", "pretrain") + cmd = [PYTHON, sft_cfg.get("script", "trainer/train_full_sft.py"), + "--save_dir", paths.get("out_dir", "out"), + "--save_weight", save_weight, + "--data_path", paths.get("sft_data"), + "--from_weight", from_weight, + "--seed", str(seed)] + base_args + method_args + run_cmd(cmd, run=args.run) + else: + print("[SKIP] train_sft (missing sft_data)") + + # ========== Stage: eval_ppl ========== + if "eval_ppl" in stages: + val_path = paths.get("val_data", "") + if ensure_exists(val_path, strict=args.strict): + eval_cfg = cfg.get("eval", {}) + target = eval_cfg.get("target", "pretrain") + prefix = cfg.get(target, {}).get("save_prefix", target) + eval_runs_ppl = paths.get("eval_runs_ppl", "eval_runs_ppl") + for seed in seeds: + for method in methods: + weight = f"{prefix}_{method}_s{seed}" + out_path = os.path.join(eval_runs_ppl, f"{weight}.json") + cmd = [PYTHON, "scripts/eval_ppl.py", + "--weight", weight, + "--save_dir", paths.get("out_dir", "out"), + "--data_path", val_path, + "--max_seq_len", str(eval_cfg.get("max_seq_len", 340)), + "--batch_size", str(eval_cfg.get("batch_size", 8)), + "--max_samples", str(eval_cfg.get("max_samples", 0)), + "--method", method, + "--out_path", out_path] + run_cmd(cmd, run=args.run) + else: + print("[SKIP] eval_ppl (missing val_data)") + + # ========== Stage: eval_prompts ========== + if "eval_prompts" in stages: + prompts 
= paths.get("prompts", "") + if ensure_exists(prompts, strict=args.strict): + eval_cfg = cfg.get("eval", {}) + target = eval_cfg.get("target", "pretrain") + prefix = cfg.get(target, {}).get("save_prefix", target) + for seed in seeds: + for method in methods: + weight = f"{prefix}_{method}_s{seed}" + cmd = [PYTHON, "scripts/eval_fixed_prompts.py", + "--weight", weight, + "--save_dir", paths.get("out_dir", "out"), + "--prompts_file", prompts, + "--out_dir", paths.get("eval_runs", "eval_runs"), + "--config", paths.get("eval_config", "eval/eval_config.json"), + "--run_name", weight] + run_cmd(cmd, run=args.run) + else: + print("[SKIP] eval_prompts (missing prompts)") + + # ========== Stage: aggregate ========== + if "aggregate" in stages: + eval_runs_ppl = paths.get("eval_runs_ppl", "eval_runs_ppl") + summary_csv = paths.get("summary_csv", "eval/summary_ppl.csv") + cmd = [PYTHON, "scripts/aggregate_eval.py", "--inputs", eval_runs_ppl, "--out", summary_csv] + run_cmd(cmd, run=args.run) + + # ========== Stage: plot ========== + if "plot" in stages: + summary_csv = paths.get("summary_csv", "eval/summary_ppl.csv") + plot_png = paths.get("plot_png", "eval/plot_ppl.png") + cmd = [PYTHON, "scripts/plot_growth.py", "--csv", summary_csv, "--x", "weight", "--y", "ppl", "--out", plot_png] + run_cmd(cmd, run=args.run) + + +if __name__ == "__main__": + main() diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index cc59cc7..ffa2110 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -15,16 +15,53 @@ from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, DistributedSampler from model.model_minimind import MiniMindConfig from dataset.lm_dataset import SFTDataset -from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler +from trainer.trainer_utils import ( + get_lr, + Logger, + is_main_process, + lm_checkpoint, + init_distributed_mode, + setup_seed, + init_model, + SkipBatchSampler, + init_neuron_mask, + set_neuron_tracking, + grow_neurons, + save_run_config, + update_run_config, + get_active_ratio_by_layer, + get_active_ratio_stats +) warnings.filterwarnings('ignore') +def get_neuron_active_ratio(model): + total = 0 + active = 0 + for m in model.modules(): + if hasattr(m, "mask"): + total += m.mask.numel() + active += int(m.mask.sum().item()) + return (active / total) if total > 0 else 1.0 + + def train_epoch(epoch, loader, iters, start_step=0, wandb=None): start_time = time.time() + tokens_seen = 0 for step, (input_ids, labels) in enumerate(loader, start=start_step + 1): input_ids = input_ids.to(args.device) labels = labels.to(args.device) + tokens_seen += input_ids.numel() + global_step = epoch * iters + step + should_update = ((step + 1) % args.accumulation_steps == 0) + update_step = global_step // args.accumulation_steps + should_grow = bool(args.neuron_growth) and should_update and (update_step > 0) and (update_step % args.grow_interval == 0) + + if args.neuron_growth: + track_activity = (args.grow_method != "random") + track_grad = (args.grow_method != "random") and should_grow + set_neuron_tracking(model, track_activity=track_activity, track_mask_grad=track_grad, ema_beta=args.neuron_ema_beta) lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) for param_group in optimizer.param_groups: param_group['lr'] = lr @@ -43,6 +80,18 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): scaler.step(optimizer) 
scaler.update() + if should_grow: + grow_neurons( + model, + method=args.grow_method, + grow_ratio=args.grow_ratio, + max_active_ratio=args.max_active_ratio, + score_alpha=args.grow_score_alpha, + score_beta=args.grow_score_beta, + seed=global_step + ) + set_neuron_tracking(model, track_activity=(args.grow_method != "random"), track_mask_grad=False, ema_beta=args.neuron_ema_beta) + optimizer.zero_grad(set_to_none=True) if step % args.log_interval == 0 or step == iters - 1: @@ -52,8 +101,27 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): current_logits_loss = current_loss - current_aux_loss current_lr = optimizer.param_groups[-1]['lr'] eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 - Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min') - if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min}) + tokens_per_sec = tokens_seen / max(spend_time, 1e-6) + active_ratio = get_neuron_active_ratio(model) if args.neuron_growth else None + active_stats = get_active_ratio_stats(model) if args.neuron_growth else None + active_msg = f', active_ratio: {active_ratio:.3f}' if active_ratio is not None else '' + if active_stats: + active_msg += f", active_mean: {active_stats['mean']:.3f}, active_min: {active_stats['min']:.3f}, active_max: {active_stats['max']:.3f}" + Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min, tok/s: {tokens_per_sec:.1f}{active_msg}') + if wandb: + log_data = {"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min, "tokens_per_sec": tokens_per_sec} + if active_ratio is not None: + log_data["active_ratio"] = active_ratio + if active_stats: + log_data["active_ratio_mean"] = active_stats["mean"] + log_data["active_ratio_min"] = active_stats["min"] + log_data["active_ratio_max"] = active_stats["max"] + if args.neuron_growth: + layer_ratios = get_active_ratio_by_layer(model) + for name, ratio in layer_ratios.items(): + key = name.replace(".", "_") + log_data[f"active_{key}"] = ratio + wandb.log(log_data) if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): model.eval() @@ -70,6 +138,8 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): del input_ids, labels, res, loss + return tokens_seen + if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind Full SFT") @@ -85,6 +155,17 @@ if __name__ == "__main__": parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔") + parser.add_argument("--seed", type=int, default=42, help="随机种子(复现实验用)") + # 动态神经元生长相关参数(可选) + parser.add_argument("--neuron_growth", default=0, type=int, choices=[0, 1], help="是否启用动态神经元生长") + parser.add_argument("--init_active_ratio", type=float, default=0.8, help="初始激活神经元比例") + parser.add_argument("--grow_method", type=str, default="random", choices=["random", "act_grad"], help="神经元生长方式") + parser.add_argument("--grow_interval", type=int, default=100, 
help="每隔多少次优化器更新触发生长") + parser.add_argument("--grow_ratio", type=float, default=0.02, help="每次生长激活比例") + parser.add_argument("--max_active_ratio", type=float, default=0.99, help="最多激活到多少比例") + parser.add_argument("--grow_score_alpha", type=float, default=1.0, help="活动分数权重(EMA)") + parser.add_argument("--grow_score_beta", type=float, default=1.0, help="梯度分数权重") + parser.add_argument("--neuron_ema_beta", type=float, default=0.1, help="活动EMA系数") parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)") @@ -100,12 +181,16 @@ if __name__ == "__main__": # ========== 1. 初始化环境和随机种子 ========== local_rank = init_distributed_mode() if dist.is_initialized(): args.device = f"cuda:{local_rank}" - setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) + setup_seed(args.seed + (dist.get_rank() if dist.is_initialized() else 0)) # ========== 2. 配置目录、模型参数、检查ckp ========== os.makedirs(args.save_dir, exist_ok=True) lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None + run_config_path = None + if is_main_process(): + run_name = f"{args.save_weight}_{args.hidden_size}_{time.strftime('%Y%m%d_%H%M%S')}" + run_config_path = save_run_config(args, args.save_dir, run_name=run_name, extra={"resume": bool(ckp_data)}) # ========== 3. 设置混合精度 ========== device_type = "cuda" if "cuda" in args.device else "cpu" @@ -123,6 +208,15 @@ if __name__ == "__main__": # ========== 5. 定义模型、数据、优化器 ========== model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) + if args.neuron_growth and not ckp_data: + init_neuron_mask(model, init_active_ratio=args.init_active_ratio, seed=42) + if args.neuron_growth: + set_neuron_tracking( + model, + track_activity=(args.grow_method != "random"), + track_mask_grad=False, + ema_beta=args.neuron_ema_beta + ) if args.use_compile == 1: model = torch.compile(model) Logger('torch.compile enabled') @@ -146,17 +240,27 @@ if __name__ == "__main__": model = DistributedDataParallel(model, device_ids=[local_rank]) # ========== 8. 
开始训练 ========== + train_start = time.time() + total_tokens = 0 for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) - setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() + setup_seed(args.seed + epoch); indices = torch.randperm(len(train_ds)).tolist() skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') - train_epoch(epoch, loader, len(loader) + skip, start_step, wandb) + epoch_tokens = train_epoch(epoch, loader, len(loader) + skip, start_step, wandb) else: - train_epoch(epoch, loader, len(loader), 0, wandb) + epoch_tokens = train_epoch(epoch, loader, len(loader), 0, wandb) + total_tokens += epoch_tokens + + if is_main_process() and run_config_path: + update_run_config(run_config_path, { + "train_time_sec": time.time() - train_start, + "total_tokens": total_tokens, + "final_active_ratio": get_neuron_active_ratio(model) if args.neuron_growth else None + }) # ========== 9. 清理分布进程 ========== - if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file + if dist.is_initialized(): dist.destroy_process_group() diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index a8ad97f..2291fbc 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -1,161 +1,398 @@ -import os -import sys +import os # 导入 os,用于路径/文件操作 +import sys # 导入 sys,用于修改模块搜索路径 -__package__ = "trainer" -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +__package__ = "trainer" # 指定当前脚本所属包,便于相对导入 +# 把项目根目录加入 Python 路径,方便直接运行本脚本时能找到上层模块 +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) # 将上级目录加入 sys.path -import argparse -import time -import warnings -import torch -import torch.distributed as dist -from contextlib import nullcontext -from torch import optim, nn -from torch.nn.parallel import DistributedDataParallel -from torch.utils.data import DataLoader, DistributedSampler -from model.model_minimind import MiniMindConfig -from dataset.lm_dataset import PretrainDataset -from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler +import argparse # 命令行参数解析 +import time # 计时 +import warnings # 控制警告输出 +import torch # PyTorch 主库 +import torch.distributed as dist # 分布式训练工具 +from contextlib import nullcontext # 空上下文管理器(CPU 下不用 autocast) +from torch import optim, nn # 优化器与神经网络模块 +from torch.nn.parallel import DistributedDataParallel # DDP 封装 +from torch.utils.data import DataLoader, DistributedSampler # 数据加载与分布式采样 +from model.model_minimind import MiniMindConfig # 模型配置类 +from dataset.lm_dataset import PretrainDataset # 预训练数据集类 +from trainer.trainer_utils import ( # 训练工具函数集合 + get_lr, # 计算学习率 + Logger, # 日志打印(主进程) + is_main_process, # 判断是否主进程 + lm_checkpoint, # 保存/读取断点 + init_distributed_mode, # 初始化分布式环境 + setup_seed, # 设置随机种子 + init_model, # 初始化模型与分词器 + SkipBatchSampler, # 可跳过 batch 的采样器 + init_neuron_mask, # 初始化神经元 mask + set_neuron_tracking, # 控制活动/梯度统计 + grow_neurons, # 动态激活神经元 + save_run_config, # 保存实验配置 + update_run_config, # 更新实验配置 + get_active_ratio_by_layer, # 按层统计激活比例 + get_active_ratio_stats # 激活比例统计 +) # 结束导入列表 -warnings.filterwarnings('ignore') 
+warnings.filterwarnings('ignore') # 忽略警告信息 + +# ============================================================================= +# 新手必读(最重要的几个角度): +# 1) 训练是否能跑通,最常见的问题是 data_path 指向的文件不存在。 +# 2) OOM(显存不够)优先降低:batch_size -> max_seq_len -> hidden_size/num_hidden_layers。 +# 3) loss 不下降/变成 NaN:先把学习率降 10 倍试试,再检查数据质量。 +# 4) 单卡/多卡差别:多卡时用 DistributedSampler + set_epoch,且只主进程保存。 +# 5) 混合精度:bfloat16 更稳,float16 更快但更容易数值不稳。 +# ============================================================================= + +# ============================================================================= +# 本脚本的训练主线(给新手看的超简版): +# 1) 读取 jsonl 数据 -> token -> input_ids/labels +# 2) 模型前向得到 loss +# 3) loss.backward() 反向传播 +# 4) optimizer.step() 更新参数 +# 5) 周期性打印日志、保存权重、保存断点 +# ============================================================================= -def train_epoch(epoch, loader, iters, start_step=0, wandb=None): - start_time = time.time() - for step, (input_ids, labels) in enumerate(loader, start=start_step + 1): - input_ids = input_ids.to(args.device) - labels = labels.to(args.device) - lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) - for param_group in optimizer.param_groups: - param_group['lr'] = lr - - with autocast_ctx: - res = model(input_ids, labels=labels) - loss = res.loss + res.aux_loss - loss = loss / args.accumulation_steps - - scaler.scale(loss).backward() - - if (step + 1) % args.accumulation_steps == 0: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) - - scaler.step(optimizer) - scaler.update() - - optimizer.zero_grad(set_to_none=True) - - if step % args.log_interval == 0 or step == iters - 1: - spend_time = time.time() - start_time - current_loss = loss.item() * args.accumulation_steps - current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0 - current_logits_loss = current_loss - current_aux_loss - current_lr = optimizer.param_groups[-1]['lr'] - eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 - Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min') - if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min}) - - if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): - model.eval() - moe_suffix = '_moe' if lm_config.use_moe else '' - ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' - raw_model = model.module if isinstance(model, DistributedDataParallel) else model - raw_model = getattr(raw_model, '_orig_mod', raw_model) - state_dict = raw_model.state_dict() - torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp) - lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') - model.train() - del state_dict - - del input_ids, labels, res, loss +def get_neuron_active_ratio(model): # 统计当前激活神经元比例 + total = 0 # 总神经元数 + active = 0 # 已激活神经元数 + for m in model.modules(): # 遍历所有模块 + if hasattr(m, "mask"): # 只统计带 mask 的 FFN + total += m.mask.numel() + active += int(m.mask.sum().item()) + return (active / total) if total > 0 else 1.0 # 返回比例(若无 mask 则视为 1) -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="MiniMind Pretraining") - 
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") - parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名") - parser.add_argument("--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)") - parser.add_argument("--batch_size", type=int, default=32, help="batch size") - parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率") - parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") - parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") - parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数") - parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数") - parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") - parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") - parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔") - parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") - parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") - parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)") - parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") - parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径") - parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始") - parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)") - parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") - parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名") - parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)") - args = parser.parse_args() +def train_epoch(epoch, loader, iters, start_step=0, wandb=None): # 训练一个 epoch 的函数 + """ + 训练一个 epoch + - epoch: 当前第几个 epoch(从 0 开始) + - loader: DataLoader(会产出 batch) + - iters: 这个 epoch 里总的 step 数 + - start_step: 断点续训时,跳过前面已经训练过的 step 数 + - wandb: 记录日志用(可选) + """ + # 记录本 epoch 开始时间,用于估算剩余时间 + start_time = time.time() # 当前时间戳 + tokens_seen = 0 # 统计已处理 token 数,用于吞吐率 + # 这里的 step 从 start_step + 1 开始计数,日志更直观(人类习惯从 1 开始) + for step, (input_ids, labels) in enumerate(loader, start=start_step + 1): # 遍历 DataLoader 的每个 batch + # 1) 把数据搬到指定设备(GPU/CPU) + # input_ids/labels 形状通常是 (batch_size, seq_len) + # 例:batch_size=2, seq_len=128 -> shape=(2, 128) + input_ids = input_ids.to(args.device) # 把输入 token 放到 GPU/CPU + labels = labels.to(args.device) # 把监督标签放到 GPU/CPU + tokens_seen += input_ids.numel() # 统计 tokens(含 padding,用于吞吐率) + + # 计算全局步数与是否需要“生长”神经元 + global_step = epoch * iters + step # 全局 step(从 1 开始) + should_update = ((step + 1) % args.accumulation_steps == 0) # 本步是否会更新参数 + update_step = global_step // args.accumulation_steps # 优化器更新步数(从 0 开始) + should_grow = bool(args.neuron_growth) and should_update and (update_step > 0) and (update_step % args.grow_interval == 0) + + # 根据需要开启/关闭活动与 mask 梯度追踪 + if args.neuron_growth: + track_activity = (args.grow_method != "random") # 随机增长不需要活动统计 + track_grad = (args.grow_method != "random") and should_grow # 只有在增长步才追踪 mask 梯度 + set_neuron_tracking(model, track_activity=track_activity, track_mask_grad=track_grad, ema_beta=args.neuron_ema_beta) + + # 2) 计算当前 step 的学习率(这里用余弦衰减) + # 说明:学习率不是固定的,会随训练进度变化 + # current_step = epoch * iters 
+ step 是“全局步数” + lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) # 计算当前学习率 + for param_group in optimizer.param_groups: # 遍历优化器的参数组 + param_group['lr'] = lr # 设置该参数组的学习率 + + # 3) 前向 + 计算 loss(混合精度可选) + with autocast_ctx: # GPU 下启用混合精度,CPU 下为空上下文 + # res 是模型输出对象,包含 loss / logits / aux_loss 等 + res = model(input_ids, labels=labels) # 前向计算(内部会算 loss) + # res.logits 形状通常是 (batch_size, seq_len, vocab_size) + # 例:batch_size=2, seq_len=128, vocab=6400 -> shape=(2, 128, 6400) + # loss 在模型内部计算,核心是“预测下一个 token” + # shift_logits: (batch_size, seq_len-1, vocab_size) + # shift_labels: (batch_size, seq_len-1) + # res.loss 是语言模型的交叉熵损失(预测下一个 token) + # res.aux_loss 仅在 MoE 模型中存在,用于专家负载均衡 + # 主损失 + MoE 的辅助损失(如果有) + loss = res.loss + res.aux_loss # 总损失 + # 梯度累积:把 loss 平均分摊到多次小步 + # 这样累积 N 次的梯度,等价于“大 batch”的效果 + # 等效总 batch_size = batch_size * accumulation_steps * world_size + loss = loss / args.accumulation_steps # 按累积步数缩放 loss + + # 4) 反向传播:把 loss 的梯度累积到参数上 + # 注意:这里不会立刻更新参数,只是把梯度累积在参数上 + scaler.scale(loss).backward() # 反向传播(混合精度下用 scaler) + + # 5) 每积累一定步数才更新一次参数 + # accumulation_steps=8 表示每 8 个小步更新一次参数 + if (step + 1) % args.accumulation_steps == 0: # 到达累积步数时更新参数 + # 反缩放后再裁剪梯度,避免梯度爆炸 + # clip_grad_norm_ 会限制梯度范数不超过 args.grad_clip + scaler.unscale_(optimizer) # 先取消缩放,得到真实梯度 + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) # 裁剪梯度 + + # 更新参数 + scaler.step(optimizer) # 执行一步优化器更新 + scaler.update() # 更新缩放比例 + + # 动态激活更多神经元(仅在指定步数触发) + if should_grow: + grow_neurons( + model, + method=args.grow_method, + grow_ratio=args.grow_ratio, + max_active_ratio=args.max_active_ratio, + score_alpha=args.grow_score_alpha, + score_beta=args.grow_score_beta, + seed=global_step + ) + # 生长完成后关闭 mask 梯度追踪,避免额外开销 + set_neuron_tracking(model, track_activity=(args.grow_method != "random"), track_mask_grad=False, ema_beta=args.neuron_ema_beta) + + # 清空梯度,进入下一轮 + optimizer.zero_grad(set_to_none=True) # 清零梯度,节省显存 + + # 6) 日志打印 + if step % args.log_interval == 0 or step == iters - 1: # 按间隔或最后一步打印 + spend_time = time.time() - start_time # 已花时间(秒) + # 注意:这里把 loss 乘回来,恢复到“真实的单步损失” + # 因为前面为了梯度累积把 loss 除过 + current_loss = loss.item() * args.accumulation_steps # 真实损失值 + current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0 # MoE 辅助损失 + current_logits_loss = current_loss - current_aux_loss # 语言模型主损失 + current_lr = optimizer.param_groups[-1]['lr'] # 当前学习率 + # 估算当前 epoch 剩余时间(分钟) + # eta = 已花时间 / 已完成步数 * 总步数 - 已花时间 + eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60 # 估算剩余分钟数 + tokens_per_sec = tokens_seen / max(spend_time, 1e-6) # 吞吐率(tokens/s) + # 如果启用了动态生长,额外记录当前激活比例 + active_ratio = get_neuron_active_ratio(model) if args.neuron_growth else None + active_stats = get_active_ratio_stats(model) if args.neuron_growth else None + active_msg = f', active_ratio: {active_ratio:.3f}' if active_ratio is not None else '' + if active_stats: + active_msg += f", active_mean: {active_stats['mean']:.3f}, active_min: {active_stats['min']:.3f}, active_max: {active_stats['max']:.3f}" + Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min, tok/s: {tokens_per_sec:.1f}{active_msg}') # 打印日志 + if wandb: # 如果启用 wandb + log_data = { + "loss": current_loss, + "logits_loss": current_logits_loss, + "aux_loss": current_aux_loss, + "learning_rate": current_lr, + "epoch_time": eta_min, + "tokens_per_sec": tokens_per_sec + } + if 
active_ratio is not None: + log_data["active_ratio"] = active_ratio + if active_stats: + log_data["active_ratio_mean"] = active_stats["mean"] + log_data["active_ratio_min"] = active_stats["min"] + log_data["active_ratio_max"] = active_stats["max"] + # 逐层激活比例(便于画曲线) + if args.neuron_growth: + layer_ratios = get_active_ratio_by_layer(model) + for name, ratio in layer_ratios.items(): + key = name.replace(".", "_") + log_data[f"active_{key}"] = ratio + wandb.log(log_data) # 记录到 wandb + + # 7) 保存模型(只在主进程保存,避免多卡重复写文件) + if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): # 满足保存条件且是主进程 + # eval() 主要是关闭 dropout / 避免不一致 + model.eval() # 切换到评估模式 + moe_suffix = '_moe' if lm_config.use_moe else '' # MoE 模型在文件名加后缀 + ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' # 权重保存路径 + # raw_model 兼容两种情况: + # - DDP 包裹时,真实模型在 model.module 里 + # - torch.compile 包裹时,真实模型在 _orig_mod 里 + raw_model = model.module if isinstance(model, DistributedDataParallel) else model # 取出真实模型 + raw_model = getattr(raw_model, '_orig_mod', raw_model) # 兼容 torch.compile + state_dict = raw_model.state_dict() # 获取模型参数字典 + # 只保存半精度权重,节省空间(推理时再转回) + torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp) # 保存权重到硬盘 + # 保存断点信息(可用于恢复训练) + lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') # 保存断点 + model.train() # 切回训练模式 + del state_dict # 释放权重变量 + + # 释放中间变量,减少显存占用 + # 对新手来说:这不是必须,但有助于长时间训练稳定 + del input_ids, labels, res, loss # 删除临时变量 + + return tokens_seen # 返回本 epoch 处理的 token 数 + + +if __name__ == "__main__": # 仅在直接运行本脚本时执行以下代码 + # ------------------------- + # 0. 解析训练参数 + # ------------------------- + parser = argparse.ArgumentParser(description="MiniMind Pretraining") # 创建参数解析器 + # 保存相关参数 + parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") # 输出权重目录 + parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名") # 权重前缀 + # 训练相关参数 + # batch_size:越大越快但越吃显存 + parser.add_argument("--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)") # 训练轮数 + parser.add_argument("--batch_size", type=int, default=32, help="batch size") # 每步样本数 + # learning_rate:影响收敛速度与稳定性,过大会震荡/发散 + parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率") # 初始学习率 + parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") # 训练设备 + # dtype:bfloat16 更稳,float16 更快但风险更高 + parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") # 混合精度类型 + # num_workers:数据加载并行数,太大可能反而拖慢 + parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数") # 数据加载线程数 + # accumulation_steps:梯度累积,等效扩大 batch_size(但会更慢) + parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数") # 梯度累积步数 + # grad_clip:限制梯度范数,避免梯度爆炸 + parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") # 梯度裁剪阈值 + parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") # 日志间隔 + parser.add_argument("--save_interval", type=int, default=1000, help="模型保存间隔") # 保存间隔 + parser.add_argument("--seed", type=int, default=42, help="随机种子(复现实验用)") # 随机种子 + # 动态神经元生长相关参数(可选) + parser.add_argument("--neuron_growth", default=0, type=int, choices=[0, 1], help="是否启用动态神经元生长") # 是否开启 + parser.add_argument("--init_active_ratio", type=float, default=0.8, help="初始激活神经元比例") # 初始激活比例 + parser.add_argument("--grow_method", type=str, default="random", 
choices=["random", "act_grad"], help="神经元生长方式") # 生长方法 + parser.add_argument("--grow_interval", type=int, default=100, help="每隔多少次优化器更新触发生长") # 生长间隔 + parser.add_argument("--grow_ratio", type=float, default=0.02, help="每次生长激活比例") # 每次激活比例 + parser.add_argument("--max_active_ratio", type=float, default=0.99, help="最多激活到多少比例") # 最大激活比例 + parser.add_argument("--grow_score_alpha", type=float, default=1.0, help="活动分数权重(EMA)") # 活动权重 + parser.add_argument("--grow_score_beta", type=float, default=1.0, help="梯度分数权重") # 梯度权重 + parser.add_argument("--neuron_ema_beta", type=float, default=0.1, help="活动EMA系数") # EMA 衰减系数 + # 模型结构参数 + # hidden_size/num_hidden_layers 越大:参数越多、训练越慢、显存越高 + parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") # 隐藏层维度 + parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") # 层数 + # max_seq_len 越大:上下文更长,但显存/算力增长很快 + parser.add_argument('--max_seq_len', default=340, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)") # 最大序列长度 + parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") # 是否启用 MoE + # 数据与权重加载 + # data_path 必须是 jsonl,且每行包含 {"text": "..."} + parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径") # 数据文件路径 + # from_weight:从已有权重继续训练;none 表示从头开始 + parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始") # 加载已有权重 + # from_resume:是否自动检测断点续训(保存/恢复 optimizer 状态) + parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)") # 是否断点续训 + # 日志与加速 + parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") # 是否启用 wandb + parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名") # wandb 项目名 + parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)") # 是否启用 torch.compile + args = parser.parse_args() # 解析命令行参数 # ========== 1. 初始化环境和随机种子 ========== - local_rank = init_distributed_mode() - if dist.is_initialized(): args.device = f"cuda:{local_rank}" - setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0)) + local_rank = init_distributed_mode() # 初始化分布式环境并获取本地 rank + if dist.is_initialized(): # 如果启用了分布式 + args.device = f"cuda:{local_rank}" # 让当前进程绑定到对应 GPU + # 每个进程使用不同随机种子,保证多卡可复现且不重复 + # 这样多卡不会完全“学到相同的 batch” + setup_seed(args.seed + (dist.get_rank() if dist.is_initialized() else 0)) # 设置随机种子 # ========== 2. 配置目录、模型参数、检查ckp ========== - os.makedirs(args.save_dir, exist_ok=True) - lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) - ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None + os.makedirs(args.save_dir, exist_ok=True) # 创建模型保存目录 + # 构建模型配置(这里是最小的 MiniMind 配置) + # hidden_size / num_hidden_layers 决定模型规模(大=更慢=更耗显存) + lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) # 创建模型配置 + # 若启用断点续训,就尝试从 checkpoints 里读取 + ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None # 读取断点 + # 保存本次实验配置(仅主进程),便于论文复现实验 + run_config_path = None + if is_main_process(): + run_name = f"{args.save_weight}_{args.hidden_size}_{time.strftime('%Y%m%d_%H%M%S')}" + run_config_path = save_run_config(args, args.save_dir, run_name=run_name, extra={"resume": bool(ckp_data)}) # 写入配置文件 # ========== 3. 
设置混合精度 ========== - device_type = "cuda" if "cuda" in args.device else "cpu" - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 - autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) + device_type = "cuda" if "cuda" in args.device else "cpu" # 判断设备类型 + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 # 选择混合精度类型 + # CPU 不使用混合精度;GPU 才启用 autocast + # bfloat16 相对更稳定;float16 更快但可能不稳定 + autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype) # 设置 autocast 上下文 # ========== 4. 配wandb ========== - wandb = None - if args.use_wandb and is_main_process(): - import swanlab as wandb - wandb_id = ckp_data.get('wandb_id') if ckp_data else None - resume = 'must' if wandb_id else None - wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" - wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) + wandb = None # 默认不启用 wandb + if args.use_wandb and is_main_process(): # 仅主进程初始化 wandb + import swanlab as wandb # swanlab 与 wandb API 兼容 + wandb_id = ckp_data.get('wandb_id') if ckp_data else None # 续训时复用 run id + resume = 'must' if wandb_id else None # 若有 id 则强制恢复 + wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" # 生成 run 名 + wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume) # 初始化 wandb # ========== 5. 定义模型、数据、优化器 ========== - model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) - if args.use_compile == 1: - model = torch.compile(model) - Logger('torch.compile enabled') - train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len) - train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None - scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) - optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) + # 载入模型与分词器(若指定 from_weight 则加载权重) + model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) # 创建模型与 tokenizer + # 若启用动态神经元生长,先初始化 mask(仅在非断点恢复时) + if args.neuron_growth and not ckp_data: + init_neuron_mask(model, init_active_ratio=args.init_active_ratio, seed=42) + # 设置活动/梯度统计开关(随机增长不需要活动统计) + if args.neuron_growth: + set_neuron_tracking( + model, + track_activity=(args.grow_method != "random"), + track_mask_grad=False, + ema_beta=args.neuron_ema_beta + ) + if args.use_compile == 1: # 是否启用 torch.compile + model = torch.compile(model) # 编译模型以加速 + Logger('torch.compile enabled') # 打印提示 + # 读取预训练数据(jsonl) + # 每一行都是 {"text": "..."} 的格式 + train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len) # 创建数据集 + train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None # 多卡时使用分布式采样 + # 仅 float16 时启用 GradScaler;bfloat16 不需要 + scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) # 混合精度缩放器 + # AdamW 是语言模型最常用的优化器 + optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) # 创建优化器 # ========== 6. 
从ckp恢复状态 ========== - start_epoch, start_step = 0, 0 - if ckp_data: - model.load_state_dict(ckp_data['model']) - optimizer.load_state_dict(ckp_data['optimizer']) - scaler.load_state_dict(ckp_data['scaler']) - start_epoch = ckp_data['epoch'] - start_step = ckp_data.get('step', 0) + start_epoch, start_step = 0, 0 # 默认从头开始 + if ckp_data: # 如果找到断点 + # 恢复模型、优化器、混合精度状态 + model.load_state_dict(ckp_data['model']) # 恢复模型权重 + optimizer.load_state_dict(ckp_data['optimizer']) # 恢复优化器状态 + scaler.load_state_dict(ckp_data['scaler']) # 恢复 scaler 状态 + start_epoch = ckp_data['epoch'] # 继续的 epoch + start_step = ckp_data.get('step', 0) # 继续的 step # ========== 7. DDP包模型 ========== - if dist.is_initialized(): - model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} - model = DistributedDataParallel(model, device_ids=[local_rank]) + if dist.is_initialized(): # 如果启用了分布式 + # 这两个 buffer 在 DDP 下不需要同步 + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} # 忽略这些 buffer + # DDP 会自动帮我们做多卡梯度同步 + model = DistributedDataParallel(model, device_ids=[local_rank]) # 包装成 DDP 模型 # ========== 8. 开始训练 ========== - for epoch in range(start_epoch, args.epochs): - train_sampler and train_sampler.set_epoch(epoch) - setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() - skip = start_step if (epoch == start_epoch and start_step > 0) else 0 - batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) - loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) - if skip > 0: - Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') - train_epoch(epoch, loader, len(loader) + skip, start_step, wandb) - else: - train_epoch(epoch, loader, len(loader), 0, wandb) + train_start = time.time() + total_tokens = 0 + for epoch in range(start_epoch, args.epochs): # 遍历 epoch + # 多卡时每个 epoch 都要 set_epoch,保证采样不同 + train_sampler and train_sampler.set_epoch(epoch) # 设置采样器的 epoch + # 打乱索引(单卡) + setup_seed(args.seed + epoch) # 每个 epoch 设置不同随机种子 + indices = torch.randperm(len(train_ds)).tolist() # 生成随机索引列表 + # 断点续训时跳过已经训练过的 step + skip = start_step if (epoch == start_epoch and start_step > 0) else 0 # 需要跳过的 step 数 + # 自定义 batch_sampler,用于跳过前面若干 batch + # 注意:这里用 batch_sampler 时,DataLoader 不再传 batch_size + batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) # 创建可跳过的采样器 + # pin_memory=True 会让 CPU->GPU 拷贝更快 + # loader 每次返回的 input_ids/labels 形状是 (batch_size, max_seq_len) + loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) # 构建 DataLoader + if skip > 0: # 如果需要跳过 + Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') # 打印提示 + epoch_tokens = train_epoch(epoch, loader, len(loader) + skip, start_step, wandb) # 从指定 step 继续训练 + else: # 正常从头开始 + epoch_tokens = train_epoch(epoch, loader, len(loader), 0, wandb) # 正常训练 + total_tokens += epoch_tokens + + # 训练结束后更新配置文件(记录总 tokens 和耗时) + if is_main_process() and run_config_path: + update_run_config(run_config_path, { + "train_time_sec": time.time() - train_start, + "total_tokens": total_tokens, + "final_active_ratio": get_neuron_active_ratio(model) if args.neuron_growth else None + }) # ========== 9. 
清理分布进程 ========== - if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file + if dist.is_initialized(): # 如果启用了分布式 + dist.destroy_process_group() # 关闭分布式进程组 diff --git a/trainer/trainer_utils.py b/trainer/trainer_utils.py index 3ec1e44..443ca10 100644 --- a/trainer/trainer_utils.py +++ b/trainer/trainer_utils.py @@ -7,6 +7,10 @@ __package__ = "trainer" sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) import random import math +import json +import time +import subprocess +import platform import numpy as np import torch import torch.distributed as dist @@ -60,6 +64,88 @@ def setup_seed(seed: int): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False + +def _get_git_commit(): + """尽量获取当前 git commit(失败则返回 None)""" + try: + res = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + check=False, + text=True + ) + commit = res.stdout.strip() + return commit if commit else None + except Exception: + return None + + +def save_run_config(args, save_dir, run_name="run", extra=None): + """ + 保存训练配置,便于复现实验 + - args: argparse.Namespace + - save_dir: 保存目录 + - run_name: 文件名前缀 + - extra: 额外信息 dict + """ + os.makedirs(save_dir, exist_ok=True) + config = { + "run_name": run_name, + "timestamp": time.strftime("%Y%m%d_%H%M%S"), + "args": vars(args) if hasattr(args, "__dict__") else args, + "torch": torch.__version__, + "python": platform.python_version(), + "git_commit": _get_git_commit() + } + if extra: + config.update(extra) + path = os.path.join(save_dir, f"{run_name}_config.json") + with open(path, "w", encoding="utf-8") as f: + json.dump(config, f, ensure_ascii=False, indent=2) + return path + + +def update_run_config(path, extra): + """更新 run_config.json(追加训练完成后的统计信息)""" + if not path or not os.path.exists(path): + return + try: + with open(path, "r", encoding="utf-8") as f: + config = json.load(f) + except Exception: + config = {} + config.update(extra or {}) + with open(path, "w", encoding="utf-8") as f: + json.dump(config, f, ensure_ascii=False, indent=2) + + +def get_active_ratio_by_layer(model): + """ + 返回每个带 mask 的 FFN 模块激活比例 + - key: 模块名称(如 model.layers.0.mlp) + - value: 激活比例(0~1) + """ + ratios = {} + for name, m in model.named_modules(): + if hasattr(m, "mask"): + total = m.mask.numel() + if total > 0: + ratios[name] = float(m.mask.sum().item() / total) + return ratios + + +def get_active_ratio_stats(model): + """返回激活比例的统计值(均值/最小/最大)""" + ratios = list(get_active_ratio_by_layer(model).values()) + if not ratios: + return None + return { + "mean": float(sum(ratios) / len(ratios)), + "min": float(min(ratios)), + "max": float(max(ratios)) + } + def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs): os.makedirs(save_dir, exist_ok=True) moe_path = '_moe' if lm_config.use_moe else '' @@ -131,6 +217,121 @@ def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', sav return model.to(device), tokenizer +# ===================== 动态神经元生长相关工具函数 ===================== +def _iter_ffn_modules(model): + """遍历所有带 mask 的 FFN 模块(Dense + MoE Experts 都会被包含)""" + for m in model.modules(): + if hasattr(m, "mask") and hasattr(m, "ema_act"): + yield m + + +def set_neuron_tracking(model, track_activity=False, track_mask_grad=False, ema_beta=0.1): + """开启/关闭神经元活动与 mask 梯度的统计""" + for m in _iter_ffn_modules(model): + m.track_activity = track_activity + m.track_mask_grad = 
track_mask_grad + m.ema_beta = ema_beta + + +def init_neuron_mask(model, init_active_ratio=0.8, seed=42): + """初始化神经元 mask(随机激活部分神经元)""" + # 多卡时只在主进程生成,再广播 + if dist.is_initialized(): + if is_main_process(): + _init_neuron_mask_impl(model, init_active_ratio, seed) + # 广播到所有进程,确保 mask 一致 + for m in _iter_ffn_modules(model): + dist.broadcast(m.mask, src=0) + else: + _init_neuron_mask_impl(model, init_active_ratio, seed) + + +def _init_neuron_mask_impl(model, init_active_ratio, seed): + g = torch.Generator() + g.manual_seed(seed) + for m in _iter_ffn_modules(model): + total = m.mask.numel() + n_active = max(1, int(total * init_active_ratio)) + # 先清零,再随机选一部分置 1 + m.mask.zero_() + idx = torch.randperm(total, generator=g)[:n_active].to(m.mask.device) + m.mask[idx] = 1.0 + + +def grow_neurons( + model, + method="random", + grow_ratio=0.02, + max_active_ratio=0.99, + score_alpha=1.0, + score_beta=1.0, + seed=None +): + """ + 动态激活更多神经元 + - method: "random" 或 "act_grad" + - grow_ratio: 每次激活的比例(相对于总神经元数) + - max_active_ratio: 最多激活到多少比例 + - score_alpha/score_beta: 活动/梯度的权重 + - seed: 随机种子(用于 random) + """ + # 多卡:所有进程同时计算(内部会 all_reduce),再统一广播以保证一致 + if dist.is_initialized(): + _grow_neurons_impl(model, method, grow_ratio, max_active_ratio, score_alpha, score_beta, seed) + for m in _iter_ffn_modules(model): + dist.broadcast(m.mask, src=0) + else: + _grow_neurons_impl(model, method, grow_ratio, max_active_ratio, score_alpha, score_beta, seed) + + +def _grow_neurons_impl(model, method, grow_ratio, max_active_ratio, score_alpha, score_beta, seed): + # 用于随机选择的 generator + g = None + if seed is not None: + g = torch.Generator() + g.manual_seed(seed) + + for m in _iter_ffn_modules(model): + mask = m.mask + total = mask.numel() + active = int(mask.sum().item()) + max_active = int(total * max_active_ratio) + if active >= max_active: + continue + + n_add = max(1, int(total * grow_ratio)) + n_add = min(n_add, max_active - active) + if n_add <= 0: + continue + + if method == "random": + inactive_idx = (mask == 0).nonzero(as_tuple=False).flatten() + if inactive_idx.numel() == 0: + continue + # 从未激活的神经元中随机选 + perm = torch.randperm(inactive_idx.numel(), generator=g)[:n_add].to(inactive_idx.device) + chosen = inactive_idx[perm] + else: + # 活动 + 梯度加权得分 + score = torch.zeros_like(mask) + if score_alpha > 0: + score += score_alpha * m.ema_act + if score_beta > 0 and m._mask_proxy is not None and m._mask_proxy.grad is not None: + grad = m._mask_proxy.grad.detach().abs() + # 多卡时梯度求平均(主进程决策) + if dist.is_initialized(): + dist.all_reduce(grad, op=dist.ReduceOp.SUM) + grad /= dist.get_world_size() + score += score_beta * grad + + # 已激活的神经元设为极小,保证不会被选中 + score = score.clone() + score[mask > 0] = -1e9 + chosen = torch.topk(score, k=n_add).indices + + mask[chosen] = 1.0 + + class SkipBatchSampler(Sampler): def __init__(self, sampler, batch_size, skip_batches=0): self.sampler = sampler @@ -154,4 +355,4 @@ class SkipBatchSampler(Sampler): def __len__(self): total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size - return max(0, total_batches - self.skip_batches) \ No newline at end of file + return max(0, total_batches - self.skip_batches)
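
---

The growth utilities in `trainer/trainer_utils.py` duck-type on module attributes: `_iter_ffn_modules` yields any module exposing `mask` and `ema_act`, `set_neuron_tracking` writes `track_activity` / `track_mask_grad` / `ema_beta`, and `grow_neurons` reads `_mask_proxy.grad`. The sketch below shows one way an FFN module could satisfy that interface, assuming a SwiGLU-style block; only the attribute names are taken from the diff above, the rest is illustrative and not the actual `model/model_minimind.py` implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedFeedForward(nn.Module):
    """Hypothetical masked FFN matching the interface trainer_utils expects."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # 0/1 gate over intermediate neurons; grow_neurons() flips zeros to ones
        self.register_buffer("mask", torch.ones(intermediate_size))
        # per-neuron activity EMA, consumed as the "act" part of the growth score
        self.register_buffer("ema_act", torch.zeros(intermediate_size))
        self.track_activity = False   # toggled by set_neuron_tracking()
        self.track_mask_grad = False  # toggled only on growth steps
        self.ema_beta = 0.1
        self._mask_proxy = None       # leaf tensor whose .grad scores inactive neurons

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.silu(self.gate_proj(x)) * self.up_proj(x)
        if self.track_mask_grad:
            # use a differentiable copy of the mask so backward() leaves a per-neuron
            # gradient in self._mask_proxy.grad without changing the forward output
            self._mask_proxy = self.mask.detach().clone().requires_grad_(True)
            h = h * self._mask_proxy
        else:
            h = h * self.mask
        if self.track_activity:
            with torch.no_grad():
                # mean |activation| over batch/sequence dims -> shape (intermediate_size,)
                act = h.abs().mean(dim=tuple(range(h.dim() - 1)))
                self.ema_act.mul_(1 - self.ema_beta).add_(self.ema_beta * act)
        return self.down_proj(h)
```

The key point of the proxy trick is that the gradient with respect to the mask copy is defined even where the mask is 0, so the `act_grad` method can rank still-inactive neurons by how much switching them on would reduce the loss to first order.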
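A note on cadence: `grow_interval` counts optimizer updates, not micro-batches, so the effective growth frequency also depends on `accumulation_steps`. A quick sanity check with the defaults that appear in this diff (batch_size=32, accumulation_steps=8, grow_interval=100, init_active_ratio=0.8, grow_ratio=0.02, max_active_ratio=0.99):

```python
import math

batch_size, accumulation_steps, grow_interval = 32, 8, 100
init_ratio, grow_ratio, max_ratio = 0.8, 0.02, 0.99

micro_batches_per_event = accumulation_steps * grow_interval   # 800
samples_per_event = micro_batches_per_event * batch_size       # 25_600
events_to_saturation = math.ceil((max_ratio - init_ratio) / grow_ratio)  # 10

print(micro_batches_per_event, samples_per_event, events_to_saturation)
```

So under these defaults the mask saturates after roughly 10 growth events, i.e. about 8,000 micro-batches; on short runs it is worth checking the logged `active_ratio` to confirm the cap was actually reached.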
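Finally, the structure of `eval/pipeline_config.json` is only implied by the `cfg.get(...)` calls in `run_paper_pipeline.py`. The snippet below writes out an illustrative config covering every key the script reads (`stages`, `methods`, `seeds`, `paths`, `growth`, `pretrain`, `sft`, `eval`); all values and the output file name are assumptions for illustration, not the shipped defaults.

```python
import json

example_cfg = {
    "stages": ["make_val", "train_pretrain", "eval_ppl", "eval_prompts", "aggregate", "plot"],
    "methods": ["baseline", "random", "grad", "act", "actgrad"],
    "seeds": [42, 43],
    "paths": {
        "pretrain_data": "dataset/pretrain_hq.jsonl",
        "sft_data": "dataset/sft_mini_512.jsonl",
        "val_data": "dataset/pretrain_val.jsonl",
        "prompts": "eval/prompts.jsonl",
        "out_dir": "out",
        "eval_runs": "eval_runs",
        "eval_runs_ppl": "eval_runs_ppl",
        "eval_config": "eval/eval_config.json",
        "summary_csv": "eval/summary_ppl.csv",
        "plot_png": "eval/plot_ppl.png"
    },
    # build_method_args() merges this block into every non-baseline run
    "growth": {
        "init_active_ratio": 0.8,
        "grow_interval": 100,
        "grow_ratio": 0.02,
        "max_active_ratio": 0.99,
        "neuron_ema_beta": 0.1
    },
    "pretrain": {
        "save_prefix": "pretrain",
        "script": "trainer/train_pretrain.py",
        "args": {"epochs": 1, "batch_size": 32, "hidden_size": 512},
        "flags": []  # bool args listed here are emitted as bare --flag switches
    },
    "sft": {"enabled": False, "from_weight_mode": "match_pretrain"},
    "eval": {"target": "pretrain", "max_seq_len": 340, "batch_size": 8, "max_samples": 0}
}

with open("eval/pipeline_config.example.json", "w", encoding="utf-8") as f:
    json.dump(example_cfg, f, ensure_ascii=False, indent=2)
```

With this config, pretraining produces weights such as `out/pretrain_actgrad_s42_512.pth`; the eval stages rebuild exactly that name from `save_prefix`, the method, and the seed, which is why `save_prefix` must stay unchanged between the train and eval stages of a sweep.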