From 9bfb58ad1e482c4d122beb74a0632159d72f370b Mon Sep 17 00:00:00 2001 From: yuyu5333 <1812107659@qq.com> Date: Thu, 6 Nov 2025 22:34:22 +0800 Subject: [PATCH] Init Web UI --- .gitignore | 6 +- scripts/static/css/style.css | 338 +++++++++++++++++++++++++++++++++++ scripts/static/js/script.js | 242 +++++++++++++++++++++++++ scripts/templates/index.html | 163 +++++++++++++++++ scripts/train_web_ui.py | 231 ++++++++++++++++++++++++ 5 files changed, 979 insertions(+), 1 deletion(-) create mode 100644 scripts/static/css/style.css create mode 100644 scripts/static/js/script.js create mode 100644 scripts/templates/index.html create mode 100644 scripts/train_web_ui.py diff --git a/.gitignore b/.gitignore index aee52b1..8675a0e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,8 @@ +__pycache__ model/__pycache__ out website/ -docs-minimind/ \ No newline at end of file +docs-minimind/ +logfile +dataset +checkpoints \ No newline at end of file diff --git a/scripts/static/css/style.css b/scripts/static/css/style.css new file mode 100644 index 0000000..167b5b0 --- /dev/null +++ b/scripts/static/css/style.css @@ -0,0 +1,338 @@ +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; + line-height: 1.6; + color: #e0e0e0; + max-width: 1200px; + margin: 0 auto; + padding: 20px; + background-color: #121212; + min-height: 100vh; +} +h1 { + color: #ffffff; + text-align: center; + margin-bottom: 30px; + text-shadow: 0 2px 10px rgba(0, 0, 0, 0.5); + font-size: 2.5em; + font-weight: 700; +} +.tabs { + display: flex; + justify-content: space-between; + margin-bottom: 20px; + border-radius: 10px; + overflow: hidden; + box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1); +} +.tab { + padding: 12px 0; + cursor: pointer; + background-color: #2a2a2a; + border: none; + font-size: 16px; + font-weight: 500; + transition: all 0.3s ease; + position: relative; + color: #cccccc; + width: 30%; + text-align: center; +} +.tab.active { + background: linear-gradient(135deg, #4a148c 0%, #8e24aa 100%); + color: white; + transform: translateY(-2px); + box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3); +} +.form-container { + background: rgba(30, 30, 30, 0.9); + padding: 30px; + border-radius: 15px; + box-shadow: 0 10px 30px rgba(0, 0, 0, 0.3); + margin-bottom: 30px; + border: 1px solid #333; +} + +/* 参数卡片样式 */ +.parameter-card { + background-color: #2d2d2d; + border-radius: 8px; + padding: 20px; + margin-bottom: 20px; + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.3); + transition: transform 0.2s ease, box-shadow 0.2s ease; + padding-left: 5%; + padding-right: 5%; +} + +.parameter-card:hover { + transform: translateY(-2px); + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4); +} + +/* 卡片标题样式 */ +.card-title { + color: #e0e0e0; + font-size: 1.4rem; + font-weight: bold; + margin-top: 0; + margin-bottom: 20px; + padding-bottom: 10px; + border-bottom: 1px solid #404040; + width: 100%; +} + +/* 提交按钮容器 */ +.submit-container { + text-align: center; + margin-top: 30px; + padding-top: 20px; + border-top: 1px solid #4d4d4d; +} + +/* 参数内容容器 */ +.parameter-content { + width: 100%; +} + +.form-group { + width: 40%; + float: left; + margin-bottom: 15px; + margin-right: 10%; + box-sizing: border-box; +} + +.form-group:nth-child(2n) { + margin-right: 0; +} + +/* 确保复选框组占满整行 */ +.form-group.checkbox-group { + width: 100%; + margin-right: 0; +} + +.form-group.pretrain-sft, .form-group.lora { + /* 移除100%宽度设置,让这些参数也遵循每行两个的布局 */ +} + +.parameter-content::after { + content: ""; + display: table; + clear: both; +} +label { + display: block; + margin-bottom: 5px; + color: #ffffff; + font-weight: 600; + font-size: 0.9rem; + text-transform: uppercase; + letter-spacing: 0.5px; +} +input[type="text"], input[type="number"], select { + width: 100%; + padding: 12px 15px; + border: 2px solid #444; + border-radius: 8px; + font-size: 0.9rem; + transition: all 0.3s ease; + background-color: #2a2a2a; + color: #ffffff; +} + +/* 确保textarea也适应两列布局 */ +textarea { + width: 100%; + padding: 12px 15px; + border: 2px solid #444; + border-radius: 8px; + font-size: 0.9rem; + transition: all 0.3s ease; + background-color: #2a2a2a; + color: #ffffff; + resize: vertical; +} + +input[type="text"]:focus, input[type="number"]:focus, select:focus { + outline: none; + border-color: #8e24aa; + box-shadow: 0 0 0 3px rgba(142, 36, 170, 0.2); + background-color: #333; +} +.checkbox-group { + display: flex; + align-items: center; +} +.checkbox-group input[type="checkbox"] { + width: auto; + margin-right: 10px; +} +button { + background: linear-gradient(135deg, #4a148c 0%, #8e24aa 100%); + color: white; + border: none; + padding: 12px 25px; + border-radius: 8px; + font-size: 16px; + font-weight: 600; + cursor: pointer; + transition: all 0.3s ease; + box-shadow: 0 4px 15px rgba(0, 0, 0, 0.3); + position: relative; + overflow: hidden; +} +button:hover { + transform: translateY(-2px); + box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15); +} + +button:active { + transform: translateY(0); +} +.logs-container { + background-color: #0d0d0d; + color: #e0e0e0; + padding: 20px; + border-radius: 10px; + max-height: 300px; + overflow-y: auto; + margin-top: 15px; + font-family: 'Courier New', monospace; + box-shadow: 0 5px 15px rgba(0, 0, 0, 0.4); + transition: all 0.3s ease; + border: 1px solid #333; +} + +.logs-container:hover { + box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3); +} +.process-item { + background: rgba(30, 30, 30, 0.9); + padding: 20px; + margin-bottom: 15px; + border-radius: 12px; + box-shadow: 0 4px 15px rgba(0, 0, 0, 0.3); + transition: all 0.3s ease; + border: 1px solid #444; +} + +.process-item:hover { + transform: translateY(-2px); + box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15); +} +.process-info { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 10px; +} +.process-status { + padding: 5px 12px; + border-radius: 20px; + font-size: 12px; + font-weight: bold; + text-transform: uppercase; + letter-spacing: 0.5px; +} +.status-running { + background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%); + color: white; +} +.status-completed { + background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); + color: white; +} +.status-error { + background: linear-gradient(135deg, #ff416c 0%, #ff4b2b 100%); + color: white; +} +.btn-stop { + background: linear-gradient(135deg, #ff416c 0%, #ff4b2b 100%); + padding: 8px 15px; + font-size: 14px; + border-radius: 6px; +} +.btn-stop:hover { + transform: translateY(-1px); + box-shadow: 0 4px 10px rgba(255, 65, 108, 0.3); +} +.btn-logs { + background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); + padding: 8px 15px; + font-size: 14px; + margin-right: 10px; + border-radius: 6px; +} +.btn-logs:hover { + transform: translateY(-1px); + box-shadow: 0 4px 10px rgba(79, 172, 254, 0.3); +} +.hidden { + display: none; +} +.section-title { + color: #ffffff; + font-size: 20px; + margin-bottom: 25px; + text-shadow: 0 2px 5px rgba(0, 0, 0, 0.5); + font-weight: 600; + padding-bottom: 10px; + border-bottom: 3px solid rgba(142, 36, 170, 0.3); +} + +/* 添加滚动条样式 */ +::-webkit-scrollbar { + width: 8px; +} + +::-webkit-scrollbar-track { + background: rgba(255, 255, 255, 0.05); + border-radius: 10px; +} + +::-webkit-scrollbar-thumb { + background: linear-gradient(135deg, #4a148c 0%, #8e24aa 100%); + border-radius: 10px; +} + +::-webkit-scrollbar-thumb:hover { + background: linear-gradient(135deg, #6a1b9a 0%, #ab47bc 100%); +} + +/* 添加动画效果 */ +@keyframes fadeIn { + from { + opacity: 0; + transform: translateY(10px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +.tab-content { + animation: fadeIn 0.5s ease-out; +} + +/* 响应式设计 */ +@media (max-width: 768px) { + body { + padding: 10px; + } + + .tabs { + flex-direction: column; + } + + .tab { + margin-right: 0; + margin-bottom: 5px; + border-radius: 5px; + } + + .form-container { + padding: 20px; + } +} \ No newline at end of file diff --git a/scripts/static/js/script.js b/scripts/static/js/script.js new file mode 100644 index 0000000..9d1756e --- /dev/null +++ b/scripts/static/js/script.js @@ -0,0 +1,242 @@ +// 标签切换功能 +function openTab(evt, tabName) { + var i, tabContent, tabLinks; + tabContent = document.getElementsByClassName("tab-content"); + for (i = 0; i < tabContent.length; i++) { + tabContent[i].classList.add("hidden"); + } + tabLinks = document.getElementsByClassName("tab"); + for (i = 0; i < tabLinks.length; i++) { + tabLinks[i].classList.remove("active"); + } + document.getElementById(tabName).classList.remove("hidden"); + evt.currentTarget.classList.add("active"); + + // 如果切换到进程页面,刷新进程列表 + if (tabName === 'processes') { + loadProcesses(); + } else if (tabName === 'logfiles') { + loadLogFiles(); + } +} + +// 根据训练类型显示/隐藏特定参数 +document.getElementById('train_type').addEventListener('change', function() { + const trainType = this.value; + const pretrainSftFields = document.querySelectorAll('.pretrain-sft'); + const loraFields = document.querySelectorAll('.lora'); + + pretrainSftFields.forEach(field => { + field.style.display = (trainType === 'pretrain' || trainType === 'sft') ? 'block' : 'none'; + }); + + loraFields.forEach(field => { + field.style.display = trainType === 'lora' ? 'block' : 'none'; + }); + + // 设置默认值 + if (trainType === 'pretrain') { + document.getElementById('save_dir').value = '../out'; + document.getElementById('save_weight').value = 'pretrain'; + document.getElementById('epochs').value = '1'; + document.getElementById('batch_size').value = '32'; + document.getElementById('learning_rate').value = '5e-4'; + document.getElementById('data_path').value = '../dataset/pretrain_hq.jsonl'; + document.getElementById('from_weight').value = 'none'; + document.getElementById('log_interval').value = '100'; + document.getElementById('save_interval').value = '100'; + } else if (trainType === 'sft') { + document.getElementById('save_dir').value = '../out'; + document.getElementById('save_weight').value = 'full_sft'; + document.getElementById('epochs').value = '2'; + document.getElementById('batch_size').value = '16'; + document.getElementById('learning_rate').value = '5e-7'; + document.getElementById('data_path').value = '../dataset/sft_mini_512.jsonl'; + document.getElementById('from_weight').value = 'pretrain'; + document.getElementById('log_interval').value = '100'; + document.getElementById('save_interval').value = '100'; + } else if (trainType === 'lora') { + document.getElementById('save_dir').value = '../out/lora'; + document.getElementById('lora_name').value = 'lora_identity'; + document.getElementById('epochs').value = '50'; + document.getElementById('batch_size').value = '32'; + document.getElementById('learning_rate').value = '1e-4'; + document.getElementById('data_path').value = '../dataset/lora_identity.jsonl'; + document.getElementById('from_weight').value = 'full_sft'; + document.getElementById('log_interval').value = '10'; + document.getElementById('save_interval').value = '1'; + } +}); + +// 初始触发一次change事件以设置默认值 +document.getElementById('train_type').dispatchEvent(new Event('change')); + +// 加载进程列表 +function loadProcesses() { + fetch('/processes') + .then(response => response.json()) + .then(data => { + const processList = document.getElementById('process-list'); + processList.innerHTML = ''; + + if (data.length === 0) { + processList.innerHTML = '

暂无训练进程

'; + return; + } + + data.forEach(process => { + const processItem = document.createElement('div'); + processItem.className = 'process-item'; + + let statusClass = ''; + let statusText = ''; + if (process.running) { + statusClass = 'status-running'; + statusText = '运行中'; + } else if (process.error) { + statusClass = 'status-error'; + statusText = '出错'; + } else { + statusClass = 'status-completed'; + statusText = '已完成'; + } + + processItem.innerHTML = ` +
+
+ ${process.train_type} - ${process.start_time} +
+
+ ${statusText} +
+
+
+ + ${process.running ? `` : ''} +
+ + `; + + processList.appendChild(processItem); + }); + }); +} + +// 显示日志 +function showLogs(processId) { + const logsContainer = document.getElementById(`logs-${processId}`); + logsContainer.classList.toggle('hidden'); + + if (!logsContainer.classList.contains('hidden')) { + fetch(`/logs/${processId}`) + .then(response => response.text()) + .then(logs => { + logsContainer.textContent = logs; + logsContainer.scrollTop = logsContainer.scrollHeight; + }); + } +} + +// 停止进程 +function stopProcess(processId) { + if (confirm('确定要停止这个训练进程吗?')) { + fetch(`/stop/${processId}`, { + method: 'POST' + }) + .then(() => { + loadProcesses(); + }); + } +} + +// 表单提交处理 +document.getElementById('train-form').addEventListener('submit', function(e) { + e.preventDefault(); + const formData = new FormData(this); + const data = Object.fromEntries(formData.entries()); + + fetch('/train', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(data) + }) + .then(response => response.json()) + .then(result => { + if (result.success) { + alert('训练已开始!'); + openTab(event, 'processes'); + } else { + alert('训练启动失败:' + result.error); + } + }); +}); + +// 加载日志文件列表 +function loadLogFiles() { + fetch('/logfiles') + .then(response => response.json()) + .then(data => { + const logfilesList = document.getElementById('logfiles-list'); + logfilesList.innerHTML = ''; + + if (data.length === 0) { + logfilesList.innerHTML = '

暂无日志文件

'; + return; + } + + // 按日期倒序排序 + data.sort((a, b) => new Date(b.modified_time) - new Date(a.modified_time)); + + data.forEach(logfile => { + const fileItem = document.createElement('div'); + fileItem.className = 'process-item'; + + // 从文件名提取训练类型 + let trainType = '未知'; + if (logfile.filename.includes('train_pretrain_')) { + trainType = 'pretrain'; + } else if (logfile.filename.includes('train_sft_')) { + trainType = 'sft'; + } else if (logfile.filename.includes('train_lora_')) { + trainType = 'lora'; + } + + fileItem.innerHTML = ` +
+
+ ${trainType} - ${logfile.modified_time} +
+
+ 已保存 +
+
+
+ +
+ + `; + + logfilesList.appendChild(fileItem); + }); + }); +} + +// 查看日志文件内容 +function viewLogFile(filename, button) { + const safeFilename = filename.replace(/\./g, '-'); + const logContainer = button.closest('.process-item').querySelector(`#log-content-${safeFilename}`); + logContainer.classList.toggle('hidden'); + + if (!logContainer.classList.contains('hidden') && logContainer.textContent === '') { + logContainer.textContent = '加载日志中...'; + + fetch(`/logfile-content/${encodeURIComponent(filename)}`) + .then(response => response.text()) + .then(logs => { + logContainer.textContent = logs; + logContainer.scrollTop = 0; + }); + } +} \ No newline at end of file diff --git a/scripts/templates/index.html b/scripts/templates/index.html new file mode 100644 index 0000000..06a1afe --- /dev/null +++ b/scripts/templates/index.html @@ -0,0 +1,163 @@ + + + + + + MiniMind 训练 Web UI + + + +

MiniMind 训练 Web UI

+ +
+ + + +
+ +
+
+

选择训练类型并配置参数

+
+ +
+

基础训练参数

+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+

模型结构参数

+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+

硬件配置

+
+
+ + +
+
+
+ + +
+

模型保存与恢复

+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+
+
+ + +
+

其他设置

+
+
+
+ + +
+
+
+
+ +
+ +
+
+
+
+ + + + + + + + \ No newline at end of file diff --git a/scripts/train_web_ui.py b/scripts/train_web_ui.py new file mode 100644 index 0000000..8feaaf1 --- /dev/null +++ b/scripts/train_web_ui.py @@ -0,0 +1,231 @@ +import os +import sys +import subprocess +import threading +import json +from flask import Flask, render_template, request, jsonify, redirect, url_for +import time + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +app = Flask(__name__, template_folder='templates', static_folder='static') + +# 存储训练进程的信息 +training_processes = {} + +# 启动训练进程 +def start_training_process(train_type, params): + # 获取脚本所在目录的绝对路径 + script_dir = os.path.dirname(os.path.abspath(__file__)) + # 使用详细的时间戳作为进程ID和日志文件名 + process_id = time.strftime('%Y%m%d_%H%M%S') + # 构建logfile目录的绝对路径 + log_dir = os.path.join(script_dir, '../logfile') + log_dir = os.path.abspath(log_dir) + log_file = os.path.join(log_dir, f"train_{train_type}_{process_id}.log") + + # 确保日志目录存在 + os.makedirs(log_dir, exist_ok=True) + + # 构建命令 + if train_type == 'pretrain': + script_path = '../trainer/train_pretrain.py' + cmd = [sys.executable, script_path] + if 'save_weight' in params: + cmd.extend(['--save_weight', params['save_weight']]) + elif train_type == 'sft': + script_path = '../trainer/train_full_sft.py' + cmd = [sys.executable, script_path] + if 'save_weight' in params: + cmd.extend(['--save_weight', params['save_weight']]) + elif train_type == 'lora': + script_path = '../trainer/train_lora.py' + cmd = [sys.executable, script_path] + if 'lora_name' in params: + cmd.extend(['--lora_name', params['lora_name']]) + else: + return None + + # 添加通用参数 + for key, value in params.items(): + if key not in ['train_type', 'save_weight', 'lora_name']: + # 特殊处理布尔标志参数 + if key in ['use_wandb', 'from_resume']: + if value == '1': # 只有当值为1时才添加这个标志 + cmd.append(f'--{key}') + else: + # 确保log_interval和save_interval参数正确传递 + cmd.extend([f'--{key}', str(value)]) + + # 创建日志文件 + with open(log_file, 'w') as f: + f.write(f"开始训练 {train_type} 进程\n") + f.write(f"命令: {' '.join(cmd)}\n\n") + + # 启动进程 + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + cwd=os.path.dirname(os.path.abspath(__file__)) + ) + + # 存储进程信息 + training_processes[process_id] = { + 'process': process, + 'train_type': train_type, + 'log_file': log_file, + 'start_time': time.strftime('%Y-%m-%d %H:%M:%S'), + 'running': True, + 'error': False + } + + # 开始读取输出 + def read_output(): + try: + while True: + output = process.stdout.readline() + if output == '' and process.poll() is not None: + break + if output: + with open(log_file, 'a') as f: + f.write(output) + # 检查进程是否成功结束 + if process.returncode != 0: + training_processes[process_id]['error'] = True + finally: + training_processes[process_id]['running'] = False + + # 启动线程读取输出 + threading.Thread(target=read_output, daemon=True).start() + + return process_id + +# Flask路由 +@app.route('/') +def index(): + return render_template('index.html') + +@app.route('/train', methods=['POST']) +def train(): + data = request.json + train_type = data.get('train_type') + + # 移除不相关的参数 + params = data.copy() + + # 处理复选框参数 + if 'from_resume' not in params: + params['from_resume'] = '0' + if 'use_wandb' not in params: + params['use_wandb'] = '0' + + # 启动训练进程 + process_id = start_training_process(train_type, params) + + if process_id: + return jsonify({'success': True, 'process_id': process_id}) + else: + return jsonify({'success': False, 'error': '无效的训练类型'}) + +@app.route('/processes') +def processes(): + result = [] + for process_id, info in training_processes.items(): + result.append({ + 'id': process_id, + 'train_type': info['train_type'], + 'start_time': info['start_time'], + 'running': info['running'], + 'error': info['error'] + }) + return jsonify(result) + +@app.route('/logs/') +def logs(process_id): + # 直接从本地logfile目录读取日志文件 + # 获取脚本所在目录的绝对路径 + script_dir = os.path.dirname(os.path.abspath(__file__)) + # 构建logfile目录的绝对路径 + log_dir = os.path.join(script_dir, '../logfile') + log_dir = os.path.abspath(log_dir) + + if os.path.exists(log_dir): + for filename in os.listdir(log_dir): + if filename.endswith(f'{process_id}.log'): + log_file = os.path.join(log_dir, filename) + if os.path.exists(log_file): + try: + with open(log_file, 'r', encoding='utf-8') as f: + return f.read() + except Exception as e: + return f'读取日志失败: {str(e)}' + return '日志文件不存在或已被删除' + +@app.route('/logfiles') +def get_logfiles(): + # 获取脚本所在目录的绝对路径 + script_dir = os.path.dirname(os.path.abspath(__file__)) + # 构建logfile目录的绝对路径 + log_dir = os.path.join(script_dir, '../logfile') + log_dir = os.path.abspath(log_dir) + + logfiles = [] + if os.path.exists(log_dir): + for filename in os.listdir(log_dir): + if filename.endswith('.log') and filename.startswith('train_'): + file_path = os.path.join(log_dir, filename) + try: + modified_time = os.path.getmtime(file_path) + formatted_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(modified_time)) + logfiles.append({ + 'filename': filename, + 'modified_time': formatted_time, + 'size': os.path.getsize(file_path) + }) + except Exception as e: + continue + return jsonify(logfiles) + +@app.route('/logfile-content/') +def get_logfile_content(filename): + # 获取脚本所在目录的绝对路径 + script_dir = os.path.dirname(os.path.abspath(__file__)) + # 构建logfile目录的绝对路径 + log_dir = os.path.join(script_dir, '../logfile') + log_dir = os.path.abspath(log_dir) + + # 安全检查:防止路径遍历攻击 + if '..' in filename or '/' in filename or '\\' in filename: + return '非法的文件名' + + log_file = os.path.join(log_dir, filename) + if os.path.exists(log_file) and os.path.isfile(log_file): + try: + with open(log_file, 'r', encoding='utf-8') as f: + return f.read() + except Exception as e: + return f'读取日志失败: {str(e)}' + return '日志文件不存在' + + +@app.route('/stop/', methods=['POST']) +def stop(process_id): + if process_id in training_processes and training_processes[process_id]['running']: + process = training_processes[process_id]['process'] + # 在Windows上使用terminate,在Unix上尝试优雅终止 + try: + process.terminate() + # 等待进程结束 + process.wait(timeout=5) + except subprocess.TimeoutExpired: + # 如果超时,强制杀死 + process.kill() + + training_processes[process_id]['running'] = False + return jsonify({'success': True}) + return jsonify({'success': False}) + +if __name__ == '__main__': + app.run(host='0.0.0.0', port=5000, debug=True) \ No newline at end of file