mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
support ppo and Training web code refactoring
This commit is contained in:
parent
237744d58a
commit
d66a7945db
40
trainer_web/start_web_ui.sh
Executable file
40
trainer_web/start_web_ui.sh
Executable file
@ -0,0 +1,40 @@
|
||||
#!/bin/bash
# Launch the MiniMind training Web UI as a background service.
#
# Steps:
#   1. Refuse to start if the PID recorded in train_web_ui.pid is still alive;
#      otherwise discard the stale PID file.
#   2. Start train_web_ui.py under nohup, sending all output to a timestamped
#      log file in ../logfile.
#   3. Record the PID, then verify the process survived startup before
#      reporting success.
#
# Exit status: 0 on success, 1 if already running or if startup failed.

# Resolve the directory containing this script so relative paths work from anywhere.
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
cd "$SCRIPT_DIR" || exit 1

PID_FILE="train_web_ui.pid"

# Check whether an instance is already running.
if [ -f "$PID_FILE" ]; then
    pid=$(cat "$PID_FILE")
    if [ -n "$pid" ] && ps -p "$pid" > /dev/null 2>&1; then
        echo "Web UI 服务已经在运行 (PID: $pid)"
        exit 1
    else
        echo "删除旧的PID文件"
        rm -f "$PID_FILE"
    fi
fi

# Create the log directory.
LOG_DIR="../logfile"
mkdir -p "$LOG_DIR"

# Build a timestamped log file name.
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="$LOG_DIR/web_ui_$TIMESTAMP.log"

echo "启动 MiniMind Web UI 服务..."
echo "日志文件: $LOG_FILE"

# Start the service detached from the terminal.
nohup python -u train_web_ui.py > "$LOG_FILE" 2>&1 &

# Keep the PID in a variable: the server deletes the PID file itself when it
# cannot find a free port, so re-reading the file later is unreliable.
pid=$!
echo "$pid" > "$PID_FILE"

sleep 2

# Only claim success if the process is still alive after the grace period.
if ! ps -p "$pid" > /dev/null 2>&1; then
    echo "服务启动失败,请查看日志: $LOG_FILE"
    rm -f "$PID_FILE"
    exit 1
fi

echo "服务已启动! PID: $pid"
echo "访问地址: http://localhost:5000"
# The server picks the next free port when 5000 is taken; the log has the real one.
echo "(若5000端口被占用,实际端口请查看日志)"
echo "停止命令: kill $pid or ./trainer_web/stop_web_ui.sh"
|
||||
|
Before Width: | Height: | Size: 615 KiB After Width: | Height: | Size: 615 KiB |
@ -32,12 +32,19 @@ function openTab(evt, tabName) {
|
||||
document.getElementById('train_type').addEventListener('change', function() {
|
||||
const trainType = this.value;
|
||||
const pretrainSftFields = document.querySelectorAll('.pretrain-sft');
|
||||
const fromWeightFields = document.querySelectorAll('.from-weight');
|
||||
const loraFields = document.querySelectorAll('.lora');
|
||||
const dpoFields = document.querySelectorAll('.dpo');
|
||||
const dpoParamCard = document.querySelector('.parameter-card.dpo');
|
||||
const ppoFields = document.querySelectorAll('.ppo');
|
||||
const ppoParamCard = document.querySelector('.parameter-card.ppo');
|
||||
|
||||
pretrainSftFields.forEach(field => {
|
||||
field.style.display = (trainType === 'pretrain' || trainType === 'sft') ? 'block' : 'none';
|
||||
field.style.display = (trainType === 'pretrain' || trainType === 'sft' || trainType === 'dpo' || trainType === 'ppo') ? 'block' : 'none';
|
||||
});
|
||||
|
||||
fromWeightFields.forEach(field => {
|
||||
field.style.display = (trainType !== 'ppo') ? 'block' : 'none';
|
||||
});
|
||||
|
||||
loraFields.forEach(field => {
|
||||
@ -48,10 +55,18 @@ document.getElementById('train_type').addEventListener('change', function() {
|
||||
field.style.display = trainType === 'dpo' ? 'block' : 'none';
|
||||
});
|
||||
|
||||
ppoFields.forEach(field => {
|
||||
field.style.display = trainType === 'ppo' ? 'block' : 'none';
|
||||
});
|
||||
|
||||
if (dpoParamCard) {
|
||||
dpoParamCard.style.display = trainType === 'dpo' ? 'block' : 'none';
|
||||
}
|
||||
|
||||
if (ppoParamCard) {
|
||||
ppoParamCard.style.display = trainType === 'ppo' ? 'block' : 'none';
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if (trainType === 'pretrain') {
|
||||
document.getElementById('save_dir').value = '../out';
|
||||
@ -114,6 +129,27 @@ document.getElementById('train_type').addEventListener('change', function() {
|
||||
document.getElementById('num_hidden_layers').value = '8';
|
||||
document.getElementById('max_seq_len').value = '1024';
|
||||
document.getElementById('use_moe').value = '0';
|
||||
} else if (trainType === 'ppo') {
|
||||
document.getElementById('save_dir').value = '../out';
|
||||
document.getElementById('save_weight').value = 'ppo_actor';
|
||||
document.getElementById('epochs').value = '1';
|
||||
document.getElementById('batch_size').value = '2';
|
||||
document.getElementById('learning_rate').value = '8e-8';
|
||||
document.getElementById('data_path').value = '../dataset/rlaif-mini.jsonl';
|
||||
document.getElementById('log_interval').value = '1';
|
||||
document.getElementById('save_interval').value = '10';
|
||||
// PPO特有参数默认值
|
||||
document.getElementById('clip_epsilon').value = '0.1';
|
||||
document.getElementById('vf_coef').value = '0.5';
|
||||
document.getElementById('kl_coef').value = '0.02';
|
||||
document.getElementById('reasoning').value = '1';
|
||||
document.getElementById('update_old_actor_freq').value = '4';
|
||||
document.getElementById('reward_model_path').value = '../../internlm2-1_8b-reward';
|
||||
// 模型结构参数默认值
|
||||
document.getElementById('hidden_size').value = '512';
|
||||
document.getElementById('num_hidden_layers').value = '8';
|
||||
document.getElementById('max_seq_len').value = '66';
|
||||
document.getElementById('use_moe').value = '0';
|
||||
}
|
||||
});
|
||||
|
||||
24
trainer_web/stop_web_ui.sh
Executable file
24
trainer_web/stop_web_ui.sh
Executable file
@ -0,0 +1,24 @@
|
||||
#!/bin/bash
# Stop the MiniMind training Web UI recorded in train_web_ui.pid.
#
# Sends SIGTERM first, waits two seconds, then escalates to SIGKILL if the
# process is still alive. The PID file is removed in every case where it
# existed, so a later start never sees a stale entry.

# Resolve the directory containing this script so relative paths work from anywhere.
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
cd "$SCRIPT_DIR" || exit 1

PID_FILE="train_web_ui.pid"

if [ -f "$PID_FILE" ]; then
    pid=$(cat "$PID_FILE")
    if [ -n "$pid" ] && ps -p "$pid" > /dev/null 2>&1; then
        echo "正在停止 Web UI 服务 (PID: $pid)"
        kill "$pid"
        sleep 2
        # Check whether the graceful stop worked; force-kill otherwise.
        if ps -p "$pid" > /dev/null 2>&1; then
            echo "强制停止服务..."
            kill -9 "$pid"
        fi
        # SIGKILL cannot be handled by the server, so its own PID-file cleanup
        # may not have run; remove the file here unconditionally.
        rm -f "$PID_FILE"
        echo "服务已停止"
    else
        echo "服务未运行,但存在PID文件,已删除"
        rm -f "$PID_FILE"
    fi
else
    echo "服务未运行(未找到PID文件)"
fi
|
||||
@ -33,6 +33,7 @@
|
||||
<option value="sft">SFT - Full</option>
|
||||
<option value="lora">SFT - Lora</option>
|
||||
<option value="dpo">RL - DPO</option>
|
||||
<option value="ppo">RL - PPO</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
@ -60,7 +61,7 @@
|
||||
|
||||
<!-- 强化学习参数 -->
|
||||
<div class="parameter-card dpo" style="display: none;">
|
||||
<h3 class="card-title">强化学习参数</h3>
|
||||
<h3 class="card-title">强化学习参数 (DPO)</h3>
|
||||
<div class="parameter-content">
|
||||
<div class="form-group">
|
||||
<label for="beta">DPO Beta 参数:</label>
|
||||
@ -69,6 +70,40 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- PPO强化学习参数 -->
|
||||
<div class="parameter-card ppo" style="display: none;">
|
||||
<h3 class="card-title">强化学习参数 (PPO)</h3>
|
||||
<div class="parameter-content">
|
||||
<div class="form-group">
|
||||
<label for="clip_epsilon">PPO剪切系数:</label>
|
||||
<input type="text" id="clip_epsilon" name="clip_epsilon" placeholder="0.2">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="vf_coef">价值函数系数:</label>
|
||||
<input type="text" id="vf_coef" name="vf_coef" placeholder="0.1">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="kl_coef">KL散度惩罚系数:</label>
|
||||
<input type="text" id="kl_coef" name="kl_coef" placeholder="0.01">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="reasoning">是否使用Reasoning模式:</label>
|
||||
<select id="reasoning" name="reasoning">
|
||||
<option value="0">否</option>
|
||||
<option value="1">是</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="update_old_actor_freq">更新旧Actor频率:</label>
|
||||
<input type="number" id="update_old_actor_freq" name="update_old_actor_freq" placeholder="10" min="1">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="reward_model_path">奖励模型路径:</label>
|
||||
<input type="text" id="reward_model_path" name="reward_model_path" placeholder="path/to/reward/model">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 模型结构参数 -->
|
||||
<div class="parameter-card">
|
||||
<h3 class="card-title">模型结构参数</h3>
|
||||
@ -115,7 +150,7 @@
|
||||
<label for="lora_name">LoRA权重名称:</label>
|
||||
<input type="text" id="lora_name" name="lora_name">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<div class="form-group from-weight">
|
||||
<label for="from_weight">基于哪个权重训练:</label>
|
||||
<input type="text" id="from_weight" name="from_weight">
|
||||
</div>
|
||||
1
trainer_web/train_web_ui.pid
Normal file
1
trainer_web/train_web_ui.pid
Normal file
@ -0,0 +1 @@
|
||||
2746051
|
||||
@ -4,8 +4,11 @@ import subprocess
|
||||
import threading
|
||||
import json
|
||||
import socket
|
||||
import atexit
|
||||
import signal
|
||||
from flask import Flask, render_template, request, jsonify, redirect, url_for
|
||||
import time
|
||||
import psutil
|
||||
|
||||
# 尝试导入torch来检测GPU
|
||||
try:
|
||||
@ -46,6 +49,12 @@ app = Flask(__name__, template_folder='templates', static_folder='static')
|
||||
# 存储训练进程的信息
|
||||
training_processes = {}
|
||||
|
||||
# 进程信息持久化文件
|
||||
PROCESSES_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'training_processes.json')
|
||||
|
||||
# PID文件
|
||||
PID_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'train_web_ui.pid')
|
||||
|
||||
# 启动训练进程
|
||||
def start_training_process(train_type, params):
|
||||
# 获取脚本所在目录的绝对路径
|
||||
@ -102,17 +111,36 @@ def start_training_process(train_type, params):
|
||||
cmd.extend(['--accumulation_steps', params['accumulation_steps']])
|
||||
if 'grad_clip' in params and params['grad_clip']:
|
||||
cmd.extend(['--grad_clip', params['grad_clip']])
|
||||
elif train_type == 'ppo':
|
||||
script_path = '../trainer/train_ppo.py'
|
||||
if use_torchrun:
|
||||
cmd = ['torchrun', '--nproc_per_node', str(gpu_num), script_path]
|
||||
else:
|
||||
cmd = [sys.executable, script_path]
|
||||
# 添加PPO特定参数
|
||||
if 'clip_epsilon' in params and params['clip_epsilon']:
|
||||
cmd.extend(['--clip_epsilon', params['clip_epsilon']])
|
||||
if 'vf_coef' in params and params['vf_coef']:
|
||||
cmd.extend(['--vf_coef', params['vf_coef']])
|
||||
if 'kl_coef' in params and params['kl_coef']:
|
||||
cmd.extend(['--kl_coef', params['kl_coef']])
|
||||
if 'reasoning' in params and params['reasoning']:
|
||||
cmd.extend(['--reasoning', params['reasoning']])
|
||||
if 'update_old_actor_freq' in params and params['update_old_actor_freq']:
|
||||
cmd.extend(['--update_old_actor_freq', params['update_old_actor_freq']])
|
||||
if 'reward_model_path' in params and params['reward_model_path']:
|
||||
cmd.extend(['--reward_model_path', params['reward_model_path']])
|
||||
else:
|
||||
return None
|
||||
|
||||
# 添加通用参数
|
||||
for key, value in params.items():
|
||||
# 跳过特殊参数和DPO特有参数,以及gpu_num参数(因为已经在torchrun命令中使用)
|
||||
if key in ['train_type', 'save_weight', 'lora_name', 'train_monitor', 'beta', 'accumulation_steps', 'grad_clip', 'gpu_num']:
|
||||
# 跳过特殊参数和DPO、PPO特有参数,以及gpu_num参数(因为已经在torchrun命令中使用)
|
||||
# 对于PPO训练,跳过--from_weight参数
|
||||
if key in ['train_type', 'save_weight', 'lora_name', 'train_monitor', 'beta', 'accumulation_steps', 'grad_clip', 'gpu_num', 'clip_epsilon', 'vf_coef', 'kl_coef', 'reasoning', 'update_old_actor_freq', 'reward_model_path'] or (train_type == 'ppo' and key == 'from_weight'):
|
||||
continue
|
||||
|
||||
# 对于from_resume参数,需要正确传递参数值
|
||||
if key == 'from_resume':
|
||||
elif key == 'from_resume':
|
||||
# 确保传递参数名和参数值
|
||||
cmd.extend([f'--{key}', str(value)])
|
||||
else:
|
||||
@ -429,13 +457,99 @@ def find_available_port(start_port=5000, max_attempts=100):
|
||||
return port
|
||||
return None
|
||||
|
||||
def save_processes_info():
    """Persist the training-process registry to ``PROCESSES_FILE`` as JSON.

    Each entry of ``training_processes`` is reduced to its serializable
    fields; a live ``subprocess.Popen`` handle stored under ``'process'`` is
    represented only by a PID.

    The JSON is first written to a sibling ``.tmp`` file and then moved into
    place with ``os.replace``, so a write that is interrupted midway does not
    truncate the previously saved file.

    Errors are reported on stdout and otherwise swallowed (best effort):
    this function is used as an exit hook and must not raise.
    """
    tmp_path = PROCESSES_FILE + '.tmp'
    try:
        # Build a copy of the registry that contains no process objects.
        serializable_processes = {}
        for key, info in training_processes.items():
            process = info.get('process')
            if isinstance(process, subprocess.Popen):
                # Prefer an explicitly recorded PID, else take it from the handle.
                pid = info.get('pid', process.pid)
            else:
                pid = info.get('pid')
            serializable_processes[key] = {
                'pid': pid,
                'train_type': info['train_type'],
                'log_file': info['log_file'],
                'start_time': info['start_time'],
                'running': info['running'],
                'error': info.get('error', False),
                'manually_stopped': info.get('manually_stopped', False),
            }

        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_processes, f, ensure_ascii=False, indent=2)
        # Swap the finished file into place in a single step.
        os.replace(tmp_path, PROCESSES_FILE)
    except Exception as e:
        print(f"保存进程信息失败: {str(e)}")
        # Do not leave a half-written temporary file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
||||
|
||||
def load_processes_info():
    """Restore the training-process registry from ``PROCESSES_FILE``.

    Every saved record is copied into ``training_processes``. A record that
    was saved as running keeps that flag only if its PID still resolves to a
    process that is running and not a zombie; otherwise ``'running'`` is
    reset to ``False``. Any failure is reported on stdout and loading stops
    at that point.
    """
    global training_processes
    try:
        if not os.path.exists(PROCESSES_FILE):
            return

        with open(PROCESSES_FILE, 'r', encoding='utf-8') as f:
            saved_records = json.load(f)

        for key, record in saved_records.items():
            if record['running']:
                # NOTE(review): only the PID is checked, so a PID recycled by
                # an unrelated process would still count as alive — confirm
                # whether that matters for callers.
                try:
                    proc = psutil.Process(record['pid'])
                    still_alive = proc.is_running() and proc.status() != 'zombie'
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    # Process is gone, or we are not permitted to inspect it.
                    still_alive = False
                if not still_alive:
                    record['running'] = False
            # Records are restored whether or not the process is still alive.
            training_processes[key] = record
    except Exception as e:
        print(f"加载进程信息失败: {str(e)}")
|
||||
|
||||
def handle_exit(signum, frame):
    """Signal handler: save process records, remove the PID file, then exit.

    Args:
        signum: Number of the received signal (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    print("正在保存进程信息...")
    save_processes_info()
    # Best-effort removal of the PID file. Attempting the delete directly
    # avoids the exists-then-remove race, and catching only OSError keeps
    # unrelated exceptions from being silently discarded.
    try:
        os.remove(PID_FILE)
    except OSError:
        pass
    sys.exit(0)
|
||||
|
||||
# Install the exit handler for Ctrl+C (SIGINT).
signal.signal(signal.SIGINT, handle_exit)
# SIGTERM is registered only where the signal module defines it.
try:
    signal.signal(signal.SIGTERM, handle_exit)
except AttributeError:
    pass

# Also save the process records on interpreter exit.
atexit.register(save_processes_info)
|
||||
|
||||
if __name__ == '__main__':
    # Restore the records of training runs saved by a previous session.
    load_processes_info()

    # Write this server's PID to a file that identifies the web process.
    with open(PID_FILE, 'w') as f:
        f.write(str(os.getpid()))

    # Try the default port 5000; fall back to another free port if it is taken.
    port = find_available_port(5000)
    if port is not None:
        print(f"启动Flask服务器在 http://0.0.0.0:{port}")
        print("使用nohup启动可保持服务持续运行: nohup python -u scripts/train_web_ui.py &")
        # Bind to 0.0.0.0 for compatibility with VSCode port forwarding.
        # The server is started exactly once and with debug off: it listens on
        # all interfaces, so the interactive debugger must not be exposed.
        app.run(host='0.0.0.0', port=port, debug=False)
    else:
        print("无法找到可用的端口,请检查系统端口占用情况")
        # Best-effort removal of the PID file written above.
        try:
            os.remove(PID_FILE)
        except OSError:
            pass
        sys.exit(1)
|
||||
29
trainer_web/training_processes.json
Normal file
29
trainer_web/training_processes.json
Normal file
@ -0,0 +1,29 @@
|
||||
{
|
||||
"20251111_030331": {
|
||||
"pid": 2327112,
|
||||
"train_type": "pretrain",
|
||||
"log_file": "/file_system/wyz/minimind/logfile/train_pretrain_20251111_030331.log",
|
||||
"start_time": "2025-11-11 03:03:31",
|
||||
"running": false,
|
||||
"error": false,
|
||||
"manually_stopped": false
|
||||
},
|
||||
"20251111_030507": {
|
||||
"pid": 2332304,
|
||||
"train_type": "sft",
|
||||
"log_file": "/file_system/wyz/minimind/logfile/train_sft_20251111_030507.log",
|
||||
"start_time": "2025-11-11 03:05:07",
|
||||
"running": false,
|
||||
"error": false,
|
||||
"manually_stopped": false
|
||||
},
|
||||
"20251111_035709": {
|
||||
"pid": 2626329,
|
||||
"train_type": "ppo",
|
||||
"log_file": "/file_system/wyz/minimind/logfile/train_ppo_20251111_035709.log",
|
||||
"start_time": "2025-11-11 03:57:09",
|
||||
"running": false,
|
||||
"error": true,
|
||||
"manually_stopped": false
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user