support ppo and Training web code refactoring

This commit is contained in:
yuyu5333 2025-11-11 07:40:38 +00:00
parent 237744d58a
commit d66a7945db
9 changed files with 287 additions and 8 deletions

40
trainer_web/start_web_ui.sh Executable file
View File

@ -0,0 +1,40 @@
#!/bin/bash
# 获取脚本所在目录
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
cd "$SCRIPT_DIR"
# 检查是否已经有实例在运行
if [ -f "train_web_ui.pid" ]; then
pid=$(cat "train_web_ui.pid")
if ps -p "$pid" > /dev/null 2>&1; then
echo "Web UI 服务已经在运行 (PID: $pid)"
exit 1
else
echo "删除旧的PID文件"
rm "train_web_ui.pid"
fi
fi
# 创建日志目录
LOG_DIR="../logfile"
mkdir -p "$LOG_DIR"
# 生成时间戳
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="$LOG_DIR/web_ui_$TIMESTAMP.log"
echo "启动 MiniMind Web UI 服务..."
echo "日志文件: $LOG_FILE"
# 使用nohup启动服务
nohup python -u train_web_ui.py > "$LOG_FILE" 2>&1 &
# 保存PID
echo $! > "train_web_ui.pid"
sleep 2
echo "服务已启动! PID: $(cat "train_web_ui.pid")"
echo "访问地址: http://localhost:5000"
echo "停止命令: kill $(cat "train_web_ui.pid") or ./trainer_web/stop_web_ui.sh"

View File

Before

Width:  |  Height:  |  Size: 615 KiB

After

Width:  |  Height:  |  Size: 615 KiB

View File

@ -32,12 +32,19 @@ function openTab(evt, tabName) {
document.getElementById('train_type').addEventListener('change', function() {
const trainType = this.value;
const pretrainSftFields = document.querySelectorAll('.pretrain-sft');
const fromWeightFields = document.querySelectorAll('.from-weight');
const loraFields = document.querySelectorAll('.lora');
const dpoFields = document.querySelectorAll('.dpo');
const dpoParamCard = document.querySelector('.parameter-card.dpo');
const ppoFields = document.querySelectorAll('.ppo');
const ppoParamCard = document.querySelector('.parameter-card.ppo');
pretrainSftFields.forEach(field => {
field.style.display = (trainType === 'pretrain' || trainType === 'sft') ? 'block' : 'none';
field.style.display = (trainType === 'pretrain' || trainType === 'sft' || trainType === 'dpo' || trainType === 'ppo') ? 'block' : 'none';
});
fromWeightFields.forEach(field => {
field.style.display = (trainType !== 'ppo') ? 'block' : 'none';
});
loraFields.forEach(field => {
@ -48,10 +55,18 @@ document.getElementById('train_type').addEventListener('change', function() {
field.style.display = trainType === 'dpo' ? 'block' : 'none';
});
ppoFields.forEach(field => {
field.style.display = trainType === 'ppo' ? 'block' : 'none';
});
if (dpoParamCard) {
dpoParamCard.style.display = trainType === 'dpo' ? 'block' : 'none';
}
if (ppoParamCard) {
ppoParamCard.style.display = trainType === 'ppo' ? 'block' : 'none';
}
// 设置默认值
if (trainType === 'pretrain') {
document.getElementById('save_dir').value = '../out';
@ -114,6 +129,27 @@ document.getElementById('train_type').addEventListener('change', function() {
document.getElementById('num_hidden_layers').value = '8';
document.getElementById('max_seq_len').value = '1024';
document.getElementById('use_moe').value = '0';
} else if (trainType === 'ppo') {
document.getElementById('save_dir').value = '../out';
document.getElementById('save_weight').value = 'ppo_actor';
document.getElementById('epochs').value = '1';
document.getElementById('batch_size').value = '2';
document.getElementById('learning_rate').value = '8e-8';
document.getElementById('data_path').value = '../dataset/rlaif-mini.jsonl';
document.getElementById('log_interval').value = '1';
document.getElementById('save_interval').value = '10';
// PPO特有参数默认值
document.getElementById('clip_epsilon').value = '0.1';
document.getElementById('vf_coef').value = '0.5';
document.getElementById('kl_coef').value = '0.02';
document.getElementById('reasoning').value = '1';
document.getElementById('update_old_actor_freq').value = '4';
document.getElementById('reward_model_path').value = '../../internlm2-1_8b-reward';
// 模型结构参数默认值
document.getElementById('hidden_size').value = '512';
document.getElementById('num_hidden_layers').value = '8';
document.getElementById('max_seq_len').value = '66';
document.getElementById('use_moe').value = '0';
}
});

24
trainer_web/stop_web_ui.sh Executable file
View File

@ -0,0 +1,24 @@
#!/bin/bash
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
cd "$SCRIPT_DIR"
if [ -f "train_web_ui.pid" ]; then
pid=$(cat "train_web_ui.pid")
if ps -p "$pid" > /dev/null 2>&1; then
echo "正在停止 Web UI 服务 (PID: $pid)"
kill "$pid"
sleep 2
# 检查是否成功停止
if ps -p "$pid" > /dev/null 2>&1; then
echo "强制停止服务..."
kill -9 "$pid"
fi
echo "服务已停止"
else
echo "服务未运行但存在PID文件已删除"
rm "train_web_ui.pid"
fi
else
echo "服务未运行未找到PID文件"
fi

View File

@ -33,6 +33,7 @@
<option value="sft">SFT - Full</option>
<option value="lora">SFT - Lora</option>
<option value="dpo">RL - DPO</option>
<option value="ppo">RL - PPO</option>
</select>
</div>
<div class="form-group">
@ -60,7 +61,7 @@
<!-- 强化学习参数 -->
<div class="parameter-card dpo" style="display: none;">
<h3 class="card-title">强化学习参数</h3>
<h3 class="card-title">强化学习参数 (DPO)</h3>
<div class="parameter-content">
<div class="form-group">
<label for="beta">DPO Beta 参数:</label>
@ -69,6 +70,40 @@
</div>
</div>
<!-- PPO强化学习参数 -->
<div class="parameter-card ppo" style="display: none;">
<h3 class="card-title">强化学习参数 (PPO)</h3>
<div class="parameter-content">
<div class="form-group">
<label for="clip_epsilon">PPO剪切系数</label>
<input type="text" id="clip_epsilon" name="clip_epsilon" placeholder="0.2">
</div>
<div class="form-group">
<label for="vf_coef">价值函数系数:</label>
<input type="text" id="vf_coef" name="vf_coef" placeholder="0.1">
</div>
<div class="form-group">
<label for="kl_coef">KL散度惩罚系数</label>
<input type="text" id="kl_coef" name="kl_coef" placeholder="0.01">
</div>
<div class="form-group">
<label for="reasoning">是否使用Reasoning模式</label>
<select id="reasoning" name="reasoning">
<option value="0"></option>
<option value="1"></option>
</select>
</div>
<div class="form-group">
<label for="update_old_actor_freq">更新旧Actor频率</label>
<input type="number" id="update_old_actor_freq" name="update_old_actor_freq" placeholder="10" min="1">
</div>
<div class="form-group">
<label for="reward_model_path">奖励模型路径:</label>
<input type="text" id="reward_model_path" name="reward_model_path" placeholder="path/to/reward/model">
</div>
</div>
</div>
<!-- 模型结构参数 -->
<div class="parameter-card">
<h3 class="card-title">模型结构参数</h3>
@ -115,7 +150,7 @@
<label for="lora_name">LoRA权重名称</label>
<input type="text" id="lora_name" name="lora_name">
</div>
<div class="form-group">
<div class="form-group from-weight">
<label for="from_weight">基于哪个权重训练:</label>
<input type="text" id="from_weight" name="from_weight">
</div>

View File

@ -0,0 +1 @@
2746051

View File

@ -4,8 +4,11 @@ import subprocess
import threading
import json
import socket
import atexit
import signal
from flask import Flask, render_template, request, jsonify, redirect, url_for
import time
import psutil
# 尝试导入torch来检测GPU
try:
@ -46,6 +49,12 @@ app = Flask(__name__, template_folder='templates', static_folder='static')
# 存储训练进程的信息
training_processes = {}
# 进程信息持久化文件
PROCESSES_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'training_processes.json')
# PID文件
PID_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'train_web_ui.pid')
# 启动训练进程
def start_training_process(train_type, params):
# 获取脚本所在目录的绝对路径
@ -102,17 +111,36 @@ def start_training_process(train_type, params):
cmd.extend(['--accumulation_steps', params['accumulation_steps']])
if 'grad_clip' in params and params['grad_clip']:
cmd.extend(['--grad_clip', params['grad_clip']])
elif train_type == 'ppo':
script_path = '../trainer/train_ppo.py'
if use_torchrun:
cmd = ['torchrun', '--nproc_per_node', str(gpu_num), script_path]
else:
cmd = [sys.executable, script_path]
# 添加PPO特定参数
if 'clip_epsilon' in params and params['clip_epsilon']:
cmd.extend(['--clip_epsilon', params['clip_epsilon']])
if 'vf_coef' in params and params['vf_coef']:
cmd.extend(['--vf_coef', params['vf_coef']])
if 'kl_coef' in params and params['kl_coef']:
cmd.extend(['--kl_coef', params['kl_coef']])
if 'reasoning' in params and params['reasoning']:
cmd.extend(['--reasoning', params['reasoning']])
if 'update_old_actor_freq' in params and params['update_old_actor_freq']:
cmd.extend(['--update_old_actor_freq', params['update_old_actor_freq']])
if 'reward_model_path' in params and params['reward_model_path']:
cmd.extend(['--reward_model_path', params['reward_model_path']])
else:
return None
# 添加通用参数
for key, value in params.items():
# 跳过特殊参数和DPO特有参数以及gpu_num参数因为已经在torchrun命令中使用
if key in ['train_type', 'save_weight', 'lora_name', 'train_monitor', 'beta', 'accumulation_steps', 'grad_clip', 'gpu_num']:
# 跳过特殊参数和DPO、PPO特有参数以及gpu_num参数因为已经在torchrun命令中使用
# 对于PPO训练跳过--from_weight参数
if key in ['train_type', 'save_weight', 'lora_name', 'train_monitor', 'beta', 'accumulation_steps', 'grad_clip', 'gpu_num', 'clip_epsilon', 'vf_coef', 'kl_coef', 'reasoning', 'update_old_actor_freq', 'reward_model_path'] or (train_type == 'ppo' and key == 'from_weight'):
continue
# 对于from_resume参数需要正确传递参数值
if key == 'from_resume':
elif key == 'from_resume':
# 确保传递参数名和参数值
cmd.extend([f'--{key}', str(value)])
else:
@ -429,13 +457,99 @@ def find_available_port(start_port=5000, max_attempts=100):
return port
return None
def save_processes_info():
"""保存训练进程信息到文件"""
try:
# 创建一个不包含进程对象的可序列化版本
serializable_processes = {}
for pid, info in training_processes.items():
serializable_processes[pid] = {
'pid': info.get('pid', info.get('process').pid) if isinstance(info.get('process'), subprocess.Popen) else info.get('pid'),
'train_type': info['train_type'],
'log_file': info['log_file'],
'start_time': info['start_time'],
'running': info['running'],
'error': info.get('error', False),
'manually_stopped': info.get('manually_stopped', False)
}
with open(PROCESSES_FILE, 'w', encoding='utf-8') as f:
json.dump(serializable_processes, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"保存进程信息失败: {str(e)}")
def load_processes_info():
"""从文件加载训练进程信息"""
global training_processes
try:
if os.path.exists(PROCESSES_FILE):
with open(PROCESSES_FILE, 'r', encoding='utf-8') as f:
loaded_processes = json.load(f)
# 检查每个进程是否还在运行
for pid, info in loaded_processes.items():
if info['running']:
try:
# 检查进程是否还在运行
proc = psutil.Process(info['pid'])
if proc.is_running() and proc.status() != 'zombie':
# 进程仍在运行,恢复信息
training_processes[pid] = info
else:
# 进程已停止
info['running'] = False
training_processes[pid] = info
except (psutil.NoSuchProcess, psutil.AccessDenied):
# 进程不存在或无权限访问
info['running'] = False
training_processes[pid] = info
else:
# 进程已停止,直接恢复
training_processes[pid] = info
except Exception as e:
print(f"加载进程信息失败: {str(e)}")
def handle_exit(signum, frame):
"""处理程序退出信号,保存进程信息"""
print("正在保存进程信息...")
save_processes_info()
# 删除PID文件
if os.path.exists(PID_FILE):
try:
os.remove(PID_FILE)
except:
pass
sys.exit(0)
# 注册退出处理器
signal.signal(signal.SIGINT, handle_exit) # Ctrl+C
if hasattr(signal, 'SIGTERM'):
signal.signal(signal.SIGTERM, handle_exit) # 终止信号
# 注册程序退出时的处理函数
atexit.register(save_processes_info)
if __name__ == '__main__':
# 加载已保存的进程信息
load_processes_info()
# 创建PID文件用于标识web进程
with open(PID_FILE, 'w') as f:
f.write(str(os.getpid()))
# 尝试使用默认端口5000如果被占用则自动寻找可用端口
port = find_available_port(5000)
if port is not None:
print(f"启动Flask服务器在 http://0.0.0.0:{port}")
print(f"使用nohup启动可保持服务持续运行: nohup python -u scripts/train_web_ui.py &")
# 使用0.0.0.0作为host以兼容VSCode的端口转发功能
app.run(host='0.0.0.0', port=port, debug=True)
app.run(host='0.0.0.0', port=port, debug=False) # 生产环境关闭debug
else:
print("无法找到可用的端口,请检查系统端口占用情况")
# 删除PID文件
if os.path.exists(PID_FILE):
try:
os.remove(PID_FILE)
except:
pass
sys.exit(1)

View File

@ -0,0 +1,29 @@
{
"20251111_030331": {
"pid": 2327112,
"train_type": "pretrain",
"log_file": "/file_system/wyz/minimind/logfile/train_pretrain_20251111_030331.log",
"start_time": "2025-11-11 03:03:31",
"running": false,
"error": false,
"manually_stopped": false
},
"20251111_030507": {
"pid": 2332304,
"train_type": "sft",
"log_file": "/file_system/wyz/minimind/logfile/train_sft_20251111_030507.log",
"start_time": "2025-11-11 03:05:07",
"running": false,
"error": false,
"manually_stopped": false
},
"20251111_035709": {
"pid": 2626329,
"train_type": "ppo",
"log_file": "/file_system/wyz/minimind/logfile/train_ppo_20251111_035709.log",
"start_time": "2025-11-11 03:57:09",
"running": false,
"error": true,
"manually_stopped": false
}
}