class MinimindClient:
    """Minimal HTTP client for the MiniMind trainer web service.

    Every call goes through :meth:`_request`, which sends JSON bodies, adds a
    bearer token when an API key is known, and converts ``urllib`` HTTP/URL
    failures into ``RuntimeError``.
    """

    def __init__(self, base_url, api_key=None, timeout=10):
        """Store connection settings.

        Args:
            base_url: Service root URL; trailing slashes are stripped.
            api_key: Optional bearer token; ``None`` is stored as ``''``.
            timeout: Timeout in seconds handed to ``urlopen``.
        """
        self.base_url = base_url.rstrip('/')
        self.api_key = api_key or ''
        self.timeout = timeout

    @staticmethod
    def _quote_segment(value):
        """Percent-encode *value* so it is safe as one URL path segment."""
        # Local import: the module itself only imports urllib.request/error.
        from urllib.parse import quote
        # safe='' also escapes '/', so an id or file name can never add path
        # components, and '#', '?', spaces and non-ASCII characters no longer
        # truncate or break the request line.
        return quote(str(value), safe='')

    def _request(self, method, path, body=None, expect_text=False):
        """Send one request and return the decoded response.

        Args:
            method: HTTP verb, e.g. ``'GET'``.
            path: Path appended verbatim to ``base_url`` (already escaped).
            body: Optional JSON-serialisable payload.
            expect_text: Return the body as text instead of parsing JSON.

        Returns:
            The response text when ``expect_text`` is true, otherwise the
            parsed JSON value (``{}`` for an empty body).

        Raises:
            RuntimeError: On an HTTP error status or a URL/connection error.
        """
        url = f"{self.base_url}{path}"
        headers = {
            'Content-Type': 'application/json',
            'Cache-Control': 'no-cache',
        }
        if self.api_key:
            headers['Authorization'] = f"Bearer {self.api_key}"
        data = None if body is None else json.dumps(body).encode('utf-8')
        req = urllib.request.Request(url, data=data, headers=headers, method=method)
        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                raw = resp.read()
        except urllib.error.HTTPError as e:
            msg = e.read().decode('utf-8', errors='replace')
            raise RuntimeError(f"HTTP {e.code}: {msg}") from e
        except urllib.error.URLError as e:
            raise RuntimeError(str(e)) from e
        if expect_text:
            return raw.decode('utf-8', errors='replace')
        if not raw.strip():
            # An empty body (e.g. 204 No Content) is not valid JSON; treat it
            # as an empty object instead of raising JSONDecodeError.
            return {}
        return json.loads(raw.decode('utf-8'))

    def register(self, name, email):
        """Register a user and remember the API key the server returns."""
        res = self._request('POST', '/api/register', {'name': name, 'email': email})
        if isinstance(res, dict):
            # Keep the current key when the server sends none (or null).
            self.api_key = res.get('api_key') or self.api_key
        return res

    def start_training(self, train_type, **params):
        """Start a training job; extra keyword arguments are sent as parameters."""
        payload = {'train_type': train_type}
        payload.update(params)
        return self._request('POST', '/train', payload)

    def get_processes(self):
        """Return the server's process listing."""
        return self._request('GET', '/processes', None)

    def get_logs(self, process_id):
        """Return the log text of one process."""
        return self._request(
            'GET', f"/logs/{self._quote_segment(process_id)}", None, expect_text=True
        )

    def stop(self, process_id):
        """Ask the server to stop one process."""
        return self._request('POST', f"/stop/{self._quote_segment(process_id)}", None)

    def delete(self, process_id):
        """Ask the server to delete one process."""
        return self._request('POST', f"/delete/{self._quote_segment(process_id)}", None)

    def get_logfiles(self):
        """Return the server's log-file listing."""
        return self._request('GET', '/logfiles', None)

    def get_logfile_content(self, filename):
        """Return the text content of one log file."""
        return self._request(
            'GET', f"/logfile-content/{self._quote_segment(filename)}", None, expect_text=True
        )

    def delete_logfile(self, filename):
        """Ask the server to delete one log file."""
        return self._request(
            'DELETE', f"/delete-logfile/{self._quote_segment(filename)}", None
        )
# train_type -> (script path,
#                flags emitted whenever the key is present in params,
#                flags emitted only when the value is truthy)
_TRAIN_SPECS = {
    'pretrain': ('../trainer/train_pretrain.py', ('save_weight',), ()),
    'sft': ('../trainer/train_full_sft.py', ('save_weight',), ()),
    'lora': ('../trainer/train_lora.py', ('lora_name',), ()),
    'dpo': ('../trainer/train_dpo.py', (),
            ('beta', 'accumulation_steps', 'grad_clip')),
    'ppo': ('../trainer/train_ppo.py', (),
            ('clip_epsilon', 'vf_coef', 'kl_coef', 'reasoning',
             'update_old_actor_freq', 'reward_model_path')),
    'grpo': ('../trainer/train_grpo.py', (),
             ('beta', 'num_generations', 'reasoning', 'reward_model_path')),
    'spo': ('../trainer/train_spo.py', (),
            ('beta', 'reasoning', 'reward_model_path')),
}

# Keys that are never forwarded by the generic pass-through loop: they are
# either handled per train type above, consumed by the launcher itself
# (gpu_num, train_type), or translated specially (train_monitor).
_RESERVED_KEYS = frozenset({
    'train_type', 'save_weight', 'lora_name', 'train_monitor', 'beta',
    'accumulation_steps', 'grad_clip', 'gpu_num', 'clip_epsilon', 'vf_coef',
    'kl_coef', 'reasoning', 'update_old_actor_freq', 'reward_model_path',
    'num_generations',
})

# Train types for which 'from_weight' is not forwarded.
_NO_FROM_WEIGHT_TYPES = frozenset({'ppo', 'grpo', 'spo'})


def build_command(train_type, params, gpu_num, use_torchrun):
    """Build the argv list that launches a training script.

    Args:
        train_type: One of ``'pretrain'``, ``'sft'``, ``'lora'``, ``'dpo'``,
            ``'ppo'``, ``'grpo'`` or ``'spo'``.
        params: Mapping of option name to value. Type-specific options are
            emitted first, then every remaining non-reserved key in mapping
            order as ``--key value``, then the monitoring flags derived from
            ``train_monitor``.
        gpu_num: Process count passed to ``torchrun --nproc_per_node``.
        use_torchrun: Launch through ``torchrun`` when true, otherwise through
            the current Python interpreter.

    Returns:
        A list of strings, or ``None`` when ``train_type`` is not recognised.
    """
    spec = _TRAIN_SPECS.get(train_type) if isinstance(train_type, str) else None
    if spec is None:
        return None
    script_path, present_flags, truthy_flags = spec

    if use_torchrun:
        cmd = ['torchrun', '--nproc_per_node', str(gpu_num), script_path]
    else:
        cmd = [sys.executable, script_path]

    # Every value is stringified: params usually come from JSON, so numbers
    # would otherwise end up as non-string items in the argv list.
    for key in present_flags:
        if key in params:
            cmd.extend([f'--{key}', str(params[key])])
    for key in truthy_flags:
        if params.get(key):
            cmd.extend([f'--{key}', str(params[key])])

    skip_from_weight = train_type in _NO_FROM_WEIGHT_TYPES
    for key, value in params.items():
        if key in _RESERVED_KEYS or (skip_from_weight and key == 'from_weight'):
            continue
        cmd.extend([f'--{key}', str(value)])

    monitor = params.get('train_monitor')
    if monitor in ('wandb', 'swanlab'):
        cmd.append('--use_wandb')
        if monitor == 'wandb':
            cmd.extend(['--wandb_project', 'minimind_training'])

    return cmd
#!/bin/bash
# Start the MiniMind training Web UI in the background, then wait until it
# reports a port in its log and answers the /healthz endpoint.

# Resolve the directory containing this script (portable to macOS) and run from it.
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR" || exit 1

PID_FILE="train_web_ui.pid"

# Refuse to start a second instance; discard a stale PID file.
if [ -f "$PID_FILE" ]; then
    pid=$(cat "$PID_FILE")
    if ps -p "$pid" > /dev/null 2>&1; then
        echo "Web UI 服务已经在运行 (PID: $pid)"
        exit 1
    else
        echo "删除旧的PID文件"
        rm -f "$PID_FILE"
    fi
fi

# Create the log directory.
LOG_DIR="../logfile"
mkdir -p "$LOG_DIR"

# One timestamped log file per launch.
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="$LOG_DIR/web_ui_$TIMESTAMP.log"

echo "启动 MiniMind Web UI 服务..."
echo "日志文件: $LOG_FILE"

# Dependency pre-check: fail early if flask or psutil cannot be imported.
python - <<'PY'
import sys
missing = []
for m in ('flask', 'psutil'):
    try:
        __import__(m)
    except Exception as e:
        missing.append(f"{m}: {e.__class__.__name__} {e}")
if missing:
    print("依赖缺失或不可用:\n" + "\n".join(missing))
    sys.exit(1)
PY
if [ $? -ne 0 ]; then
    echo "启动失败:请先安装缺失依赖,例如 'pip install flask psutil'"
    exit 1
fi

# Launch the service detached from the terminal; -u keeps the log unbuffered
# so the port line below becomes visible as soon as it is printed.
nohup python -u train_web_ui.py > "$LOG_FILE" 2>&1 &
SERVER_PID=$!

# Record the PID for later stop/cleanup.
echo "$SERVER_PID" > "$PID_FILE"

# Poll the log for the actual port (20 tries x 0.5 s = at most 10 seconds).
PORT=""
for i in {1..20}; do
    # Extract an address of the form http://0.0.0.0:12345, then keep the port.
    PORT=$(grep -Eo 'http://0\.0\.0\.0:[0-9]+' "$LOG_FILE" 2>/dev/null | tail -n1 | awk -F: '{print $NF}')
    if [ -n "$PORT" ]; then
        break
    fi
    # Stop waiting as soon as the server process has exited.
    if ! ps -p "$SERVER_PID" > /dev/null 2>&1; then
        break
    fi
    sleep 0.5
done

# No fallback port is assumed: if none was found, the launch is treated as failed.
# Health check: verify the port answers (20 tries; each request is capped so a
# hung connection cannot stall the script indefinitely).
if [ -n "$PORT" ]; then
    for i in {1..20}; do
        if curl -s --max-time 2 "http://localhost:$PORT/healthz" | grep -Eq '"status"[[:space:]]*:[[:space:]]*"ok"'; then
            echo "服务已启动! PID: $SERVER_PID"
            echo "访问地址: http://localhost:$PORT"
            echo "停止命令: kill $SERVER_PID or bash trainer_web/stop_web_ui.sh"
            exit 0
        fi
        sleep 0.5
    done
fi

# Startup failure: show the log tail, clean up, and exit non-zero.
echo "服务启动失败,请查看日志"
tail -n 50 "$LOG_FILE" || true

if [ -f "$PID_FILE" ]; then
    pid=$(cat "$PID_FILE")
    if ps -p "$pid" > /dev/null 2>&1; then
        kill "$pid" >/dev/null 2>&1 || true
    fi
    rm -f "$PID_FILE"
fi

exit 1
--success-grad-end: #059669; + --warning-grad-start: #f59e0b; + --warning-grad-end: #d97706; + --border: #2d3748; + --border-light: #4a5568; + --shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05); + --shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); + --shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05); + --shadow-xl: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04); + --radius-sm: 0.375rem; + --radius-md: 0.5rem; + --radius-lg: 0.75rem; + --radius-xl: 1rem; +} + +body { + font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; + line-height: 1.6; + color: var(--text); + max-width: 1400px; + margin: 0 auto; + padding: 0; + background: linear-gradient(135deg, var(--bg) 0%, #000000 100%); + min-height: 100vh; + background-attachment: fixed; + font-size: 14px; +} + +/* 头部样式 */ +.header { + display: flex; + align-items: center; + justify-content: center; + padding: 2rem 0; + margin-bottom: 2rem; + background: rgba(0, 0, 0, 0.8); + backdrop-filter: blur(10px); + border-bottom: 1px solid var(--border); + position: sticky; + top: 0; + z-index: 100; +} + +.logo { + height: 48px; + margin-right: 1rem; + vertical-align: middle; + filter: drop-shadow(0 2px 4px rgba(0, 0, 0, 0.3)); + transition: transform 0.3s ease; +} + +.logo:hover { + transform: scale(1.05); +} + +h1 { + color: var(--text); + font-size: 2.25rem; + font-weight: 700; + margin: 0; + vertical-align: middle; + background: linear-gradient(135deg, var(--accent-grad-start) 0%, var(--accent-grad-end) 100%); + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + background-clip: text; + text-shadow: none; +} +.tabs { + display: flex; + justify-content: center; + margin: 0 auto 2rem; + max-width: 800px; + background: var(--panel-bg); + border-radius: var(--radius-lg); + padding: 0.5rem; + box-shadow: var(--shadow-lg); + border: 1px solid var(--border); + 
backdrop-filter: blur(10px); +} +.tab { + padding: 0.75rem 1.5rem; + cursor: pointer; + background: transparent; + border: none; + font-size: 0.9rem; + font-weight: 500; + transition: all 0.3s ease; + position: relative; + color: var(--text-secondary); + text-align: center; + border-radius: var(--radius-md); + flex: 1; + margin: 0 0.25rem; + position: relative; + overflow: hidden; +} +.tab.active { + background: linear-gradient(135deg, var(--accent-grad-start) 0%, var(--accent-grad-end) 100%); + color: white; + font-weight: 600; + box-shadow: var(--shadow-md); + transform: translateY(-1px); +} + +.tab:hover:not(.active) { + color: var(--text); + background: rgba(139, 92, 246, 0.1); + transform: translateY(-1px); +} +.form-container { + background: var(--panel-bg); + padding: 2rem; + border-radius: var(--radius-xl); + box-shadow: var(--shadow-xl); + margin: 0 auto 2rem; + max-width: 1200px; + border: 1px solid var(--border); + backdrop-filter: blur(10px); +} + +/* 参数卡片样式 */ +.parameter-card { + background: linear-gradient(135deg, var(--card-bg) 0%, rgba(26, 26, 36, 0.8) 100%); + border-radius: var(--radius-lg); + padding: 1.5rem; + margin-bottom: 1rem; + box-shadow: var(--shadow-md); + transition: all 0.3s ease; + border: 1px solid var(--border); + position: relative; + overflow: hidden; +} + +.parameter-card::before { + content: ''; + position: absolute; + top: 0; + left: 0; + right: 0; + height: 2px; + background: linear-gradient(90deg, var(--accent-grad-start) 0%, var(--accent-grad-end) 100%); + opacity: 0; + transition: opacity 0.3s ease; +} + +.parameter-card:hover::before { + opacity: 1; +} + +.parameter-card:hover { + transform: translateY(-4px); + box-shadow: var(--shadow-lg); + border-color: var(--border-light); +} + +/* 卡片标题样式 */ +.card-title { + color: var(--text); + font-size: 1.1rem; + font-weight: 600; + margin: 0 0 1rem 0; + padding-bottom: 0.5rem; + border-bottom: 2px solid var(--accent); + width: 100%; + letter-spacing: 0.025em; +} + +/* 提交按钮容器 */ 
+.submit-container { + text-align: center; + margin-top: 30px; + padding-top: 20px; + border-top: 1px solid #4d4d4d; +} + +/* 参数内容容器 - 使用flex布局替代float */ +.parameter-content { + width: 100%; + display: grid; + grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); + gap: 1rem; + align-items: start; +} + +.form-group { + margin-bottom: 0; + box-sizing: border-box; + position: relative; +} + +/* 确保复选框组占满整行 */ +.form-group.checkbox-group { + width: 100%; +} + +.form-group.pretrain-sft, .form-group.lora { + /* 保持默认宽度,遵循每行两个的布局 */ + width: calc(40% - 8px); +} +label { + display: block; + margin-bottom: 0.5rem; + color: var(--text); + font-weight: 500; + font-size: 0.8rem; + opacity: 0.9; + text-transform: uppercase; + letter-spacing: 0.05em; + transition: color 0.3s ease; +} +input[type="text"], input[type="number"], select, textarea { + width: 100%; + padding: 0.75rem 1rem; + border: 1px solid var(--border); + border-radius: var(--radius-md); + font-size: 0.9rem; + transition: all 0.3s ease; + background: rgba(45, 55, 72, 0.5); + color: var(--text); + font-family: inherit; + box-sizing: border-box; +} + +input[type="text"]:hover, input[type="number"]:hover, select:hover, textarea:hover { + border-color: var(--border-light); + background: rgba(45, 55, 72, 0.7); +} + +/* 确保textarea也适应两列布局 */ +textarea { + resize: vertical; + min-height: 80px; +} + +input[type="text"]:focus, input[type="number"]:focus, select:focus, textarea:focus { + outline: none; + border-color: var(--accent); + box-shadow: 0 0 0 3px rgba(139, 92, 246, 0.2); + background: rgba(45, 55, 72, 0.9); + transform: translateY(-1px); +} + +/* 文件夹选择器样式 */ +.input-with-picker { + display: flex; + gap: 0.5rem; + align-items: center; +} + +.input-with-picker input { + flex: 1; +} + +.btn-picker { + background: linear-gradient(135deg, var(--info-grad-start) 0%, var(--info-grad-end) 100%); + color: white; + border: none; + padding: 0.75rem; + border-radius: var(--radius-md); + font-size: 1rem; + cursor: 
pointer; + transition: all 0.3s ease; + box-shadow: var(--shadow-sm); + min-width: 40px; + height: 40px; + display: flex; + align-items: center; + justify-content: center; +} + +.btn-picker:hover { + transform: translateY(-1px); + box-shadow: var(--shadow-md); + filter: brightness(1.1); +} + +/* 进度条样式 */ +.progress-container { + margin: 0.5rem 0; + background: rgba(45, 55, 72, 0.3); + border-radius: var(--radius-lg); + padding: 0.5rem; + border: 1px solid var(--border); +} + +.progress-bar { + width: 100%; + height: 8px; + background: rgba(45, 55, 72, 0.5); + border-radius: var(--radius-sm); + overflow: hidden; + margin: 0.5rem 0; +} + +.progress-fill { + height: 100%; + background: linear-gradient(90deg, var(--accent-grad-start) 0%, var(--accent-grad-end) 100%); + border-radius: var(--radius-sm); + transition: width 0.3s ease; + position: relative; +} + +.progress-fill::after { + content: ''; + position: absolute; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: linear-gradient(90deg, transparent, rgba(255,255,255,0.3), transparent); + animation: progress-shine 2s infinite; +} + +@keyframes progress-shine { + 0% { transform: translateX(-100%); } + 100% { transform: translateX(100%); } +} + +.progress-info { + display: flex; + justify-content: space-between; + align-items: center; + font-size: 0.8rem; + color: var(--text-secondary); + margin-top: 0.25rem; +} + +.progress-metrics { + display: flex; + gap: 1rem; + flex-wrap: wrap; + font-size: 0.85rem; + margin-top: 0.5rem; +} + +.metric-item { + display: flex; + align-items: center; + gap: 0.25rem; + padding: 0.25rem 0.5rem; + background: rgba(139, 92, 246, 0.1); + border-radius: var(--radius-sm); + border: 1px solid rgba(139, 92, 246, 0.2); +} + +.metric-label { + font-weight: 500; + color: var(--text-secondary); +} + +.metric-value { + font-weight: 600; + color: var(--accent); +} + +/* 文件浏览器模态框样式 */ +.modal { + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: rgba(0, 0, 
0, 0.8); + backdrop-filter: blur(5px); + display: flex; + align-items: center; + justify-content: center; + z-index: 1000; + transition: opacity 0.3s ease; +} + +.modal.hidden { + display: none; +} + +.modal-content { + background: var(--panel-bg); + border-radius: var(--radius-lg); + border: 1px solid var(--border); + box-shadow: var(--shadow-xl); + width: 90%; + max-width: 600px; + max-height: 80vh; + display: flex; + flex-direction: column; + overflow: hidden; +} + +.modal-header { + padding: 1rem 1.5rem; + border-bottom: 1px solid var(--border); + display: flex; + justify-content: space-between; + align-items: center; + background: rgba(139, 92, 246, 0.1); +} + +.modal-header h3 { + margin: 0; + color: var(--text); + font-size: 1.1rem; +} + +.modal-close { + background: none; + border: none; + color: var(--text-secondary); + font-size: 1.5rem; + cursor: pointer; + padding: 0; + width: 30px; + height: 30px; + display: flex; + align-items: center; + justify-content: center; + border-radius: var(--radius-sm); + transition: all 0.3s ease; +} + +.modal-close:hover { + background: rgba(239, 68, 68, 0.2); + color: #ef4444; +} + +.modal-body { + flex: 1; + padding: 1rem; + overflow-y: auto; +} + +.modal-footer { + padding: 1rem 1.5rem; + border-top: 1px solid var(--border); + display: flex; + gap: 0.5rem; + align-items: center; + background: rgba(22, 22, 32, 0.8); +} + +.modal-footer input { + flex: 1; +} + +.btn-secondary { + background: linear-gradient(135deg, var(--border-light) 0%, var(--border) 100%); + color: var(--text); + border: none; + padding: 0.5rem 1rem; + border-radius: var(--radius-md); + cursor: pointer; + transition: all 0.3s ease; + font-size: 0.8rem; + font-weight: 500; +} + +.btn-secondary:hover { + transform: translateY(-1px); + filter: brightness(1.1); +} + +/* 模态框底部样式 */ +.modal-footer { + display: flex; + gap: 0.75rem; + align-items: center; + padding-top: 1rem; + border-top: 1px solid var(--border); +} + +.modal-footer input { + flex: 1; + 
margin-right: 0.5rem; +} + +/* 改进模态框关闭按钮 */ +.modal-close { + background: none; + border: none; + font-size: 1.5rem; + color: var(--text-secondary); + cursor: pointer; + padding: 0.25rem; + border-radius: var(--radius-sm); + transition: all 0.3s ease; +} + +.modal-close:hover { + color: var(--text); + background: rgba(239, 68, 68, 0.1); +} + +/* 文件浏览器导航 */ +.file-browser-nav { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 1rem; + padding: 0.5rem; + background: rgba(45, 55, 72, 0.5); + border-radius: var(--radius-md); +} + +.nav-buttons { + display: flex; + gap: 0.5rem; +} + +.current-path { + font-family: 'Courier New', monospace; + font-size: 0.85rem; + color: var(--text-secondary); + flex: 1; + margin-right: 1rem; + padding: 0.25rem 0.5rem; + background: rgba(22, 22, 32, 0.5); + border-radius: var(--radius-sm); + border: 1px solid var(--border); +} + +.btn-navigate { + background: linear-gradient(135deg, var(--info-grad-start) 0%, var(--info-grad-end) 100%); + color: white; + border: none; + padding: 0.5rem; + border-radius: var(--radius-sm); + cursor: pointer; + font-size: 0.8rem; + transition: all 0.3s ease; +} + +.btn-navigate:hover { + transform: translateY(-1px); + filter: brightness(1.1); +} + +/* 快捷路径 */ +.quick-paths { + display: flex; + gap: 0.5rem; + margin-bottom: 1rem; + flex-wrap: wrap; +} + +.file-browser-help { + background: rgba(59, 130, 246, 0.1); + border: 1px solid rgba(59, 130, 246, 0.3); + border-radius: var(--radius-sm); + padding: 0.75rem; + margin-bottom: 1rem; + font-size: 0.8rem; + color: var(--info-grad-start); + display: flex; + align-items: center; + gap: 0.5rem; +} + +.quick-path-btn { + background: rgba(139, 92, 246, 0.1); + color: var(--accent); + border: 1px solid rgba(139, 92, 246, 0.3); + padding: 0.4rem 0.8rem; + border-radius: var(--radius-sm); + cursor: pointer; + font-size: 0.75rem; + transition: all 0.3s ease; +} + +.quick-path-btn:hover { + background: rgba(139, 92, 246, 0.2); + 
transform: translateY(-1px); +} + +/* 文件列表 */ +.file-list { + max-height: 300px; + overflow-y: auto; + border: 1px solid var(--border); + border-radius: var(--radius-md); + background: rgba(22, 22, 32, 0.3); + margin-bottom: 1rem; +} + +.file-item { + display: flex; + align-items: center; + padding: 0.75rem; + cursor: pointer; + transition: all 0.3s ease; + border-bottom: 1px solid rgba(45, 55, 72, 0.3); +} + +.file-item:last-child { + border-bottom: none; +} + +.file-item:hover { + background: rgba(139, 92, 246, 0.1); +} + +.file-item.selected { + background: rgba(139, 92, 246, 0.2); + border-left: 3px solid var(--accent); +} + +.file-item.disabled { + opacity: 0.5; + cursor: not-allowed; +} + +.file-item.disabled:hover { + background: none; +} + +.file-icon { + margin-right: 0.75rem; + font-size: 1.2rem; + width: 20px; + text-align: center; +} + +.file-name { + flex: 1; + font-size: 0.9rem; + color: var(--text); +} + +.file-info { + font-size: 0.75rem; + color: var(--text-secondary); + text-align: right; +} + +.directory { + color: var(--info-grad-start); +} + +.file { + color: var(--text-secondary); +} +.checkbox-group { + display: flex; + align-items: center; +} +.checkbox-group input[type="checkbox"] { + width: auto; + margin-right: 10px; +} +button { + background: linear-gradient(135deg, var(--accent-grad-start) 0%, var(--accent-grad-end) 100%); + color: white; + border: none; + padding: 0.75rem 1.5rem; + border-radius: var(--radius-md); + font-size: 0.9rem; + font-weight: 600; + cursor: pointer; + transition: all 0.3s ease; + box-shadow: var(--shadow-md); + position: relative; + overflow: hidden; + letter-spacing: 0.025em; + text-transform: uppercase; + font-size: 0.8rem; +} +button:hover { + transform: translateY(-2px); + box-shadow: var(--shadow-lg); + filter: brightness(1.1); +} + +button:active { + transform: translateY(0); + filter: brightness(0.95); +} + +.section-title { + color: var(--text); + font-size: 1.25rem; + margin-bottom: 1.5rem; + 
font-weight: 600; + padding-bottom: 0.5rem; + border-bottom: 2px solid var(--accent); + letter-spacing: 0.025em; +} +.logs-container { + background: linear-gradient(135deg, #0f0f15 0%, #1a1a24 100%); + color: var(--text); + padding: 1.5rem; + border-radius: var(--radius-lg); + max-height: 400px; + overflow-y: auto; + margin-top: 1rem; + font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace; + font-size: 0.8rem; + line-height: 1.4; + box-shadow: var(--shadow-md), inset 0 1px 0 rgba(255, 255, 255, 0.05); + transition: all 0.3s ease; + border: 1px solid var(--border); + position: relative; +} + +.logs-container::before { + content: ''; + position: absolute; + top: 0; + left: 0; + right: 0; + height: 1px; + background: linear-gradient(90deg, transparent 0%, var(--accent) 50%, transparent 100%); +} + +.logs-container:hover { + box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3); +} +.process-type-group { + margin-bottom: 30px; + background-color: var(--panel-bg); + border-radius: 15px; + border: 1px solid #444; + box-shadow: 0 5px 15px rgba(0, 0, 0, 0.2); + overflow: hidden; +} + +/* 标题容器样式 */ +.process-type-header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 20px; + cursor: pointer; + user-select: none; +} + +.process-type-title { + margin: 0; + color: var(--text); + font-size: 1.2em; + font-weight: 600; + border-bottom: 2px solid var(--accent); + padding-bottom: 8px; + flex-grow: 1; +} + +/* 切换按钮样式 */ +.toggle-btn { + background: none; + border: none; + color: #e0e0e0; + font-size: 0.8em; + cursor: pointer; + padding: 5px 10px; + border-radius: 5px; + transition: background-color 0.3s, transform 0.2s; + margin-left: 15px; +} + +.toggle-btn:hover { + background-color: rgba(255, 255, 255, 0.1); + transform: scale(1.1); +} + +/* 内容容器样式 */ +.process-type-content { + max-height: none; + overflow: visible; + transition: max-height 0.3s ease-in-out; + padding: 0 20px 20px 20px; +} + +.process-item { + background: 
linear-gradient(135deg, var(--card-bg) 0%, rgba(26, 26, 36, 0.8) 100%); + padding: 1.5rem; + margin-bottom: 1rem; + border-radius: var(--radius-lg); + box-shadow: var(--shadow-md); + transition: all 0.3s ease; + border: 1px solid var(--border); + position: relative; + overflow: hidden; +} + +.process-item::before { + content: ''; + position: absolute; + top: 0; + left: 0; + right: 0; + height: 2px; + background: linear-gradient(90deg, var(--accent-grad-start) 0%, var(--accent-grad-end) 100%); + opacity: 0; + transition: opacity 0.3s ease; +} + +.process-item:hover::before { + opacity: 1; +} + +.process-item:hover { + transform: translateY(-4px); + box-shadow: var(--shadow-lg); + border-color: var(--border-light); +} +.process-info { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 10px; +} +.process-status { + padding: 0.375rem 0.75rem; + border-radius: var(--radius-lg); + font-size: 0.7rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; + backdrop-filter: blur(10px); +} +.status-running { + background: linear-gradient(135deg, var(--success-grad-start) 0%, var(--success-grad-end) 100%); + color: white; +} +.status-completed { + background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); + color: white; +} +.status-error { + background: linear-gradient(135deg, #ff416c 0%, #ff4b2b 100%); + color: white; +} +.status-manual-stop { + background: linear-gradient(135deg, #4e54c8 0%, #8f94fb 100%); + color: white; +} +.btn-stop { + background: linear-gradient(135deg, var(--danger-grad-start) 0%, var(--danger-grad-end) 100%); + padding: 8px 15px; + font-size: 14px; + border-radius: 6px; +} +.btn-stop:hover { + transform: translateY(-1px); + box-shadow: 0 4px 10px rgba(255, 65, 108, 0.3); +} +.btn-logs { + background: linear-gradient(135deg, var(--info-grad-start) 0%, var(--info-grad-end) 100%); + padding: 8px 15px; + font-size: 14px; + margin-right: 10px; + border-radius: 6px; +} + +.btn-swanlab { + 
background: linear-gradient(135deg, #007bff 0%, #00bfff 100%); + padding: 8px 15px; + font-size: 14px; + margin-right: 10px; + border-radius: 6px; + color: white; +} + +.btn-swanlab:hover { + transform: translateY(-1px); + box-shadow: 0 4px 10px rgba(0, 123, 255, 0.3); +} +.btn-delete { + background: linear-gradient(135deg, #f44336 0%, #d32f2f 100%); + padding: 8px 15px; + font-size: 14px; + margin-right: 10px; + border-radius: 6px; + color: white; +} +.btn-delete:hover { + transform: translateY(-1px); + box-shadow: 0 4px 10px rgba(244, 67, 54, 0.3); +} +.btn-delete:disabled { + background: linear-gradient(135deg, #cccccc 0%, #aaaaaa 100%); + cursor: not-allowed; + transform: none; + box-shadow: none; +} +.btn-logs:hover { + transform: translateY(-1px); + box-shadow: 0 4px 10px rgba(79, 172, 254, 0.3); +} +.hidden { + display: none; +} +.section-title { + color: #ffffff; + font-size: 18px; + margin-bottom: 25px; + text-shadow: 0 2px 5px rgba(0, 0, 0, 0.5); + font-weight: 700; + padding-bottom: 10px; + border-bottom: 1px solid #e040fb; + margin-top: 16px; +} + +/* 添加滚动条样式 */ +::-webkit-scrollbar { + width: 8px; +} + +::-webkit-scrollbar-track { + background: rgba(255, 255, 255, 0.02); + border-radius: 4px; +} + +::-webkit-scrollbar-thumb { + background: linear-gradient(135deg, var(--accent-grad-start) 0%, var(--accent-grad-end) 100%); + border-radius: 4px; +} + +::-webkit-scrollbar-thumb:hover { + background: linear-gradient(135deg, #6d28d9 0%, #8b5cf6 100%); +} + +/* 添加动画效果 */ +@keyframes fadeIn { + from { + opacity: 0; + transform: translateY(20px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +@keyframes slideIn { + from { + opacity: 0; + transform: translateX(-20px); + } + to { + opacity: 1; + transform: translateX(0); + } +} + +@keyframes pulse { + 0%, 100% { + opacity: 1; + } + 50% { + opacity: 0.7; + } +} + +/* 加载动画 */ +@keyframes shimmer { + 0% { + background-position: -1000px 0; + } + 100% { + background-position: 1000px 0; + } +} + 
+.loading-shimmer { + background: linear-gradient(90deg, transparent 0%, rgba(139, 92, 246, 0.1) 50%, transparent 100%); + background-size: 1000px 100%; + animation: shimmer 2s infinite; +} + +/* 状态指示器 */ +.status-indicator { + display: inline-block; + width: 8px; + height: 8px; + border-radius: 50%; + margin-right: 0.5rem; + animation: pulse 2s infinite; +} + +.status-indicator.running { + background: linear-gradient(135deg, var(--success-grad-start) 0%, var(--success-grad-end) 100%); +} + +.status-indicator.stopped { + background: linear-gradient(135deg, var(--danger-grad-start) 0%, var(--danger-grad-end) 100%); +} + +.status-indicator.pending { + background: linear-gradient(135deg, var(--warning-grad-start) 0%, var(--warning-grad-end) 100%); +} + +/* 自定义确认对话框样式 */ +.dialog-overlay { + position: fixed; + top: 0; + left: 0; + right: 0; + bottom: 0; + background-color: rgba(0, 0, 0, 0.7); + display: flex; + align-items: center; + justify-content: center; + z-index: 1000; + opacity: 0; + transition: opacity 0.3s ease; +} + +.dialog-overlay.show { + opacity: 1; +} + +.custom-dialog { + background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); + border-radius: 8px; + box-shadow: 0 4px 20px rgba(0, 0, 0, 0.5); + max-width: 400px; + width: 90%; + transform: translateY(-20px); + opacity: 0; + transition: transform 0.3s ease, opacity 0.3s ease; + border: 1px solid rgba(255, 255, 255, 0.1); +} + +.custom-dialog.show { + transform: translateY(0); + opacity: 1; +} + +.dialog-content { + padding: 20px; +} + +.dialog-message { + color: #ffffff; + font-size: 16px; + margin-bottom: 20px; + text-align: center; + line-height: 1.5; +} + +.dialog-actions { + display: flex; + justify-content: flex-end; + gap: 12px; +} + +.dialog-button { + padding: 10px 20px; + border: none; + border-radius: 6px; + cursor: pointer; + font-size: 14px; + font-weight: 500; + transition: all 0.2s ease; +} + +.dialog-cancel { + background: linear-gradient(135deg, #4a4a4a 0%, #333333 100%); + color: 
#ffffff; +} + +.dialog-cancel:hover { + background: linear-gradient(135deg, #5a5a5a 0%, #444444 100%); + transform: translateY(-1px); +} + +.dialog-confirm { + background: linear-gradient(135deg, #6a11cb 0%, #2575fc 100%); + color: #ffffff; +} + +.dialog-confirm:hover { + background: linear-gradient(135deg, #7a21db 0%, #3585ff 100%); + transform: translateY(-1px); +} + +/* 消息弹窗样式 */ +.notification { + position: fixed; + top: 20px; + right: 20px; + padding: 15px 25px; + border-radius: 10px; + color: white; + font-weight: 600; + font-size: 16px; + z-index: 1000; + box-shadow: 0 5px 20px rgba(0, 0, 0, 0.3); + opacity: 0; + transform: translateX(100%); + transition: all 0.3s ease; + max-width: 400px; + word-wrap: break-word; +} + +/* 显示状态 */ +.notification.show { + opacity: 1; + transform: translateX(0); +} + +/* 成功通知样式 */ +.notification-success { + background: linear-gradient(135deg, #4caf50 0%, #81c784 100%); +} + +/* 错误通知样式 */ +.notification-error { + background: linear-gradient(135deg, #f44336 0%, #ef5350 100%); +} + +/* 信息通知样式 */ +.notification-info { + background: linear-gradient(135deg, #2196f3 0%, #64b5f6 100%); +} + +.tab-content { + animation: fadeIn 0.5s ease-out; + padding: 0 1rem; +} + +/* 增强现有样式 */ +.section-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 1.5rem; + padding: 0 1rem; +} + +.section-actions { + display: flex; + gap: 0.5rem; +} + +.btn-primary, .btn-refresh { + background: linear-gradient(135deg, var(--accent-grad-start) 0%, var(--accent-grad-end) 100%); + color: white; + border: none; + padding: 0.75rem 1.5rem; + border-radius: var(--radius-md); + font-size: 0.8rem; + font-weight: 600; + cursor: pointer; + transition: all 0.3s ease; + box-shadow: var(--shadow-md); + letter-spacing: 0.025em; + text-transform: uppercase; + display: inline-flex; + align-items: center; + gap: 0.5rem; +} + +.btn-primary:hover, .btn-refresh:hover { + transform: translateY(-2px); + box-shadow: var(--shadow-lg); + 
filter: brightness(1.1); +} + +.btn-icon { + font-size: 1rem; + line-height: 1; +} + +.process-type-group { + margin-bottom: 2rem; + background: var(--panel-bg); + border-radius: var(--radius-lg); + border: 1px solid var(--border); + box-shadow: var(--shadow-md); + overflow: hidden; + backdrop-filter: blur(10px); +} + +.process-type-header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 1.5rem; + cursor: pointer; + user-select: none; + background: rgba(0, 0, 0, 0.3); + transition: background-color 0.3s ease; +} + +.process-type-header:hover { + background: rgba(45, 55, 72, 0.5); +} + +.process-type-title { + margin: 0; + color: var(--text); + font-size: 1.1rem; + font-weight: 600; + border-bottom: 2px solid var(--accent); + padding-bottom: 0.5rem; + flex-grow: 1; +} + +.toggle-btn { + background: none; + border: none; + color: var(--text-secondary); + font-size: 0.8rem; + cursor: pointer; + padding: 0.5rem; + border-radius: var(--radius-sm); + transition: all 0.3s ease; + margin-left: 1rem; +} + +.toggle-btn:hover { + background: rgba(139, 92, 246, 0.1); + color: var(--text); + transform: scale(1.1); +} + +.process-type-content { + max-height: none; + overflow: visible; + transition: max-height 0.3s ease-in-out; + padding: 0 1.5rem 1.5rem; +} + +/* 按钮样式增强 */ +.btn-stop, .btn-logs, .btn-swanlab, .btn-delete { + padding: 0.5rem 1rem; + font-size: 0.75rem; + border-radius: var(--radius-md); + font-weight: 500; + letter-spacing: 0.025em; + transition: all 0.3s ease; + margin-right: 0.5rem; + margin-bottom: 0.25rem; + border: none; + cursor: pointer; + display: inline-flex; + align-items: center; + gap: 0.25rem; +} + +.btn-stop:hover, .btn-logs:hover, .btn-swanlab:hover, .btn-delete:hover { + transform: translateY(-1px); + filter: brightness(1.1); +} + +.btn-stop { + background: linear-gradient(135deg, var(--danger-grad-start) 0%, var(--danger-grad-end) 100%); +} + +.btn-logs { + background: linear-gradient(135deg, 
var(--info-grad-start) 0%, var(--info-grad-end) 100%); +} + +.btn-swanlab { + background: linear-gradient(135deg, #3b82f6 0%, #1d4ed8 100%); +} + +.btn-delete { + background: linear-gradient(135deg, #dc2626 0%, #991b1b 100%); +} + +.btn-delete:disabled { + background: linear-gradient(135deg, #4b5563 0%, #374151 100%); + cursor: not-allowed; + transform: none; + filter: none; +} + +/* 响应式设计 */ +@media (max-width: 768px) { + body { + padding: 0; + font-size: 13px; + } + + .header { + flex-direction: column; + text-align: center; + padding: 1.5rem 1rem; + position: relative; + } + + .logo { + height: 40px; + margin-right: 0; + margin-bottom: 0.5rem; + } + + h1 { + font-size: 1.75rem; + } + + .tabs { + flex-direction: column; + margin: 0 1rem 1.5rem; + max-width: none; + } + + .tab { + margin: 0.25rem 0; + padding: 0.75rem 1rem; + border-radius: var(--radius-md); + } + + .form-container { + padding: 1.5rem; + margin: 0 1rem 1.5rem; + max-width: none; + } + + .parameter-content { + grid-template-columns: 1fr; + gap: 0.75rem; + } + + .process-item { + padding: 1rem; + } + + .process-info { + flex-direction: column; + align-items: flex-start; + gap: 0.5rem; + } + + .logs-container { + padding: 1rem; + max-height: 300px; + } +} + +@media (max-width: 480px) { + .header { + padding: 1rem; + } + + h1 { + font-size: 1.5rem; + } + + .form-container { + padding: 1rem; + } + + .parameter-card { + padding: 1rem; + } + + .card-title { + font-size: 1rem; + } +} diff --git a/trainer_web/static/images/logo2.png b/trainer_web/static/images/logo2.png new file mode 100644 index 0000000..9a0b3e2 Binary files /dev/null and b/trainer_web/static/images/logo2.png differ diff --git a/trainer_web/static/js/app.js b/trainer_web/static/js/app.js new file mode 100644 index 0000000..c2dcdb9 --- /dev/null +++ b/trainer_web/static/js/app.js @@ -0,0 +1,362 @@ +import { openTab as _openTab } from './ui/tabs.js'; +import { initTrainForm } from './train/form.js'; +import { startProcessPolling, 
/**
 * Open the server-side file browser modal on behalf of a form input.
 *
 * The selection mode is derived from the input id: `data_path` selects a
 * file, `save_dir` or any id containing `reward_model_path` selects a folder,
 * and every other id uses `auto`. When the modal element is missing, an error
 * is logged and nothing is loaded.
 *
 * @param {string} inputId - id of the input that will receive the chosen path.
 */
function openRemoteFileBrowser(inputId) {
    console.log('openRemoteFileBrowser called with:', inputId);
    currentFileBrowserTarget = inputId;

    // Decide the selection mode from the id of the requesting input.
    const isFileTarget = inputId === 'data_path';
    const isFolderTarget = !isFileTarget
        && (inputId === 'save_dir' || inputId.includes('reward_model_path'));
    const mode = isFileTarget ? 'file' : (isFolderTarget ? 'folder' : 'auto');
    const modeLogMessages = {
        file: 'Mode set to: FILE selection',
        folder: 'Mode set to: FOLDER selection',
        auto: 'Mode set to: AUTO selection',
    };
    currentSelectionMode = mode;
    console.log(modeLogMessages[mode]);

    const modal = document.getElementById('file-browser-modal');
    if (!modal) {
        console.error('Modal element not found!');
        return;
    }
    modal.classList.remove('hidden');
    console.log('Modal opened successfully');

    // Start every session with nothing selected.
    selectedFilePath = null;
    const selectedPathInput = document.getElementById('selected-path');
    if (selectedPathInput) {
        selectedPathInput.value = '';
        console.log('Selected path input cleared');
    }

    // Populate the shortcut buttons and list the initial directory.
    loadQuickPaths();
    browsePath('./');
}
/**
 * Copy the currently selected path into the target input and close the
 * file browser.
 *
 * Alerts the user (and leaves the browser open) when nothing is selected,
 * no target input is registered, or the target input no longer exists.
 */
function confirmFileSelection() {
    console.log('confirmFileSelection called');
    console.log('selectedFilePath:', selectedFilePath);
    console.log('currentFileBrowserTarget:', currentFileBrowserTarget);

    // Guard: both a selection and a target are required.
    if (!selectedFilePath || !currentFileBrowserTarget) {
        console.log('Missing selection or target');
        alert('请先选择文件或文件夹');
        return;
    }

    const targetElement = document.getElementById(currentFileBrowserTarget);
    console.log('targetElement:', targetElement);

    // Guard: the target input must still be present in the document.
    if (!targetElement) {
        console.error('Target element not found:', currentFileBrowserTarget);
        alert('错误:无法找到目标输入框');
        return;
    }

    targetElement.value = selectedFilePath;
    console.log('Value set successfully');
    closeFileBrowser();
}
/**
 * Fetch the listing of `path` from the server and render it in the browser.
 *
 * Browser state (`currentBrowsePath`, `selectedFilePath`, the selected-path
 * field) is only updated once the server has returned a usable listing.
 * Previously the state was committed before the request, so a failed browse
 * left `currentBrowsePath` pointing at a directory that was never displayed
 * (and "select current directory" would then pick that wrong path).
 *
 * @param {string} path - directory to list; sent as the `path` query parameter.
 * @returns {Promise<void>} resolves after rendering or after an error alert;
 *     never rejects.
 */
async function browsePath(path) {
    console.log('browsePath called with:', path);
    try {
        // Refresh the help text / modal title for the current selection mode.
        updateHelpText();

        const response = await fetch(`/api/browse?path=${encodeURIComponent(path)}`);
        const data = await response.json();

        if (data.error) {
            alert(`浏览失败: ${data.error}`);
            return;
        }

        // The listing is valid: commit the new location and drop the old
        // selection, which belongs to the listing being replaced.
        currentBrowsePath = path;
        selectedFilePath = null;
        const selectedPathInput = document.getElementById('selected-path');
        if (selectedPathInput) {
            // A missing field must not be reported as a network failure.
            selectedPathInput.value = '';
        }

        renderFileList(data);
        console.log('File list rendered successfully');
    } catch (error) {
        console.error('浏览路径失败:', error);
        alert('浏览路径失败,请检查网络连接');
    }
}
此目录为空
'; + return; + } + + // 更新当前路径显示(使用相对路径用于显示) + document.getElementById('current-path').textContent = data.relative_path || data.current_path; + + // 存储父目录路径供导航使用 + window.currentParentPath = data.parent; + + // 先显示目录,再显示文件 + const directories = data.items.filter(item => item.type === 'directory'); + const files = data.items.filter(item => item.type === 'file'); + + // 渲染目录 + directories.forEach(item => { + const div = createFileItem(item, '📁'); + fileList.appendChild(div); + }); + + // 渲染文件(仅在文件选择模式或自动模式下显示) + if (currentSelectionMode !== 'folder') { + files.forEach(item => { + const div = createFileItem(item, '📄'); + fileList.appendChild(div); + }); + } +} + +function createFileItem(item, icon) { + const div = document.createElement('div'); + div.className = 'file-item'; + + // 根据选择模式添加适当的CSS类 + if (currentSelectionMode === 'file' && item.type === 'directory') { + // 文件选择模式下,文件夹只用于导航,不能选择 + div.classList.add('navigable'); + } else if (currentSelectionMode === 'folder' && item.type === 'file') { + // 文件夹选择模式下,文件不能被选择 + div.classList.add('disabled'); + } + + div.onclick = (event) => selectFileItem(item, event); + + div.innerHTML = ` + ${icon} + ${item.name} + ${item.type === 'file' ? 
/**
 * Format a byte count as a short human-readable string (base 1024).
 *
 * Fixes over the previous version:
 *  - sizes of 1 TB and above no longer render as "N undefined" (the unit
 *    table ended at GB and the index was not clamped);
 *  - values below 1 no longer index `sizes[-1]`;
 *  - missing, non-numeric, non-finite, or negative sizes render as "0 B"
 *    instead of "NaN undefined".
 *
 * @param {number} bytes - size in bytes.
 * @returns {string} e.g. "0 B", "512 B", "1.5 KB", "5 TB".
 */
function formatFileSize(bytes) {
    if (typeof bytes !== 'number' || !Number.isFinite(bytes) || bytes <= 0) {
        return '0 B';
    }
    const k = 1024;
    const sizes = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'];
    // Clamp so out-of-range magnitudes fall back to the nearest known unit.
    const rawIndex = Math.floor(Math.log(bytes) / Math.log(k));
    const i = Math.min(sizes.length - 1, Math.max(0, rawIndex));
    // parseFloat drops a trailing ".0" so whole values print as integers.
    return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
}
-0,0 +1,194 @@ +import { getLogFiles, getLogFileContent, deleteLogFile as apiDeleteLogFile } from '../services/apiClient.js'; +import { el } from '../utils/dom.js'; +import { showNotification } from '../ui/notify.js'; +import { showConfirmDialog } from '../ui/dialog.js'; + +export function loadLogFiles() { + return getLogFiles().then((data) => { + const list = document.getElementById('logfiles-list'); + list.innerHTML = ''; + if (data.length === 0) { + list.innerHTML = '

暂无日志文件

'; + return; + } + data.sort((a, b) => new Date(b.modified_time) - new Date(a.modified_time)); + const groups = {}; + data.forEach((f) => { + let type = '自定义训练'; + const n = f.filename; + if (n.includes('train_pretrain_')) type = 'pretrain'; + else if (n.includes('train_sft_')) type = 'sft'; + else if (n.includes('train_lora_')) type = 'lora'; + else if (n.includes('train_dpo_')) type = 'dpo'; + else if (n.includes('train_ppo_')) type = 'ppo'; + else if (n.includes('train_grpo_')) type = 'grpo'; + else if (n.includes('train_spo_')) type = 'spo'; + f.train_type = type; + if (!groups[type]) groups[type] = []; + groups[type].push(f); + }); + const order = ['pretrain', 'sft', 'lora', 'dpo', 'ppo', 'grpo', 'spo', '未知']; + [...order.filter((t) => groups[t]), ...Object.keys(groups).filter((t) => !order.includes(t))].forEach((t) => { + list.appendChild(createTypeGroupWithToggle(t, groups[t])); + }); + }); +} + +function createTypeGroupWithToggle(trainType, files) { + const group = el('div', { class: 'process-type-group' }); + const header = el('div', { class: 'process-type-header' }); + header.dataset.expanded = 'true'; + const title = el('h3', { class: 'process-type-title', text: getTrainTypeDisplayName(trainType) }); + const toggle = el('button', { class: 'toggle-btn' }); + toggle.innerHTML = '▼'; + toggle.onclick = (e) => { + e.stopPropagation(); + toggleGroup(header); + }; + header.appendChild(title); + header.appendChild(toggle); + header.onclick = () => toggleGroup(header); + const content = el('div', { class: 'process-type-content' }); + files.forEach((f) => addLogFileItemToGroup(content, f)); + group.appendChild(header); + group.appendChild(content); + return group; +} + +function toggleGroup(header) { + const expanded = header.dataset.expanded === 'true'; + const content = header.nextElementSibling; + const toggle = header.querySelector('.toggle-btn'); + if (expanded) { + header.dataset.expanded = 'false'; + content.style.maxHeight = '0'; + content.style.overflow 
/**
 * Map an internal training-type key to the label shown in group headers.
 *
 * @param {string} trainType - key such as `pretrain`, `sft`, `lora`.
 * @returns {string} the display label, or `trainType` itself when the key
 *     has no label.
 */
function getTrainTypeDisplayName(trainType) {
    const displayLabels = {
        pretrain: '预训练 (Pretrain)',
        sft: '全参数监督微调 (SFT - Full)',
        lora: 'LoRA监督微调 (SFT - Lora)',
        dpo: '直接偏好优化 (RL - DPO)',
        ppo: 'PPO',
        grpo: 'GRPO',
        spo: 'SPO',
    };
    const label = displayLabels[trainType];
    // Unknown keys are shown verbatim.
    return label ? label : trainType;
}
+
${logfile.filename}
+
+ 已保存 + ${logfile.modified_time} +
+
+
+ + +
/**
 * Attach click handlers to the action buttons of one log-file list item.
 *
 * Buttons are located by data attribute; any button that is not present in
 * the item is skipped. Each handler receives the log file's name and the
 * button that was clicked.
 *
 * @param {Element} item - the rendered list item.
 * @param {{filename: string}} logfile - the log file the item represents.
 */
function bindItemButtons(item, logfile) {
    // selector -> action; actions stay lazy so nothing runs until a click.
    const bindings = [
        ['[data-view]', (button) => viewLogFile(logfile.filename, button)],
        ['[data-del]', (button) => deleteLogFile(logfile.filename, button)],
    ];
    for (const [selector, action] of bindings) {
        const button = item.querySelector(selector);
        if (button) {
            button.addEventListener('click', () => action(button));
        }
    }
}
/**
 * Return the last (most recent) match of `pattern` in `text`, or null.
 *
 * `String.prototype.matchAll` throws a TypeError for a regex without the
 * `g` flag, so a global copy of the pattern is built when needed.
 *
 * @param {string} text - text to search.
 * @param {RegExp} pattern - pattern to look for (any flags).
 * @returns {RegExpMatchArray|null} the final match, including `index`.
 */
function findLastMatch(text, pattern) {
    const flags = pattern.flags.includes('g') ? pattern.flags : `${pattern.flags}g`;
    let last = null;
    for (const match of text.matchAll(new RegExp(pattern.source, flags))) {
        last = match;
    }
    return last;
}

/**
 * Describe the remaining training time.
 *
 * Looks for an explicit remaining-time statement in the log text (the most
 * recent occurrence of the first pattern that matches). When none yields a
 * non-zero duration, falls back to the number of epochs left, and finally
 * to a "calculating" placeholder.
 *
 * @param {number} current - current epoch.
 * @param {number} total - total number of epochs.
 * @param {string} logText - recent log output.
 * @returns {string} user-facing description of the remaining time.
 */
function calculateRemainingTime(current, total, logText) {
    const timePatterns = [
        /remaining[\s:=]\s*(\d+)[\s:]?(\d+)?[\s:]?(\d+)?/i, // remaining: 1:30:45 or remaining: 90
        /ETA[\s:=]\s*(\d+):(\d+):(\d+)/i, // ETA: 1:30:45
        /预计剩余[\s:=]\s*(\d+)[\s小时]*[\s:]?(\d+)?[\s分钟]*/i,
        /剩余时间[\s:=]\s*(\d+)[\s小时]*[\s:]?(\d+)?[\s分钟]*/i,
        /time left[\s:=]\s*(\d+)[\s:]?(\d+)?[\s:]?(\d+)?/i, // time left: 1:30:45
        /还需[\s:=]\s*(\d+)[\s小时]*[\s:]?(\d+)?[\s分钟]*/i
    ];

    for (const pattern of timePatterns) {
        // Use the newest statement in the log window, not the oldest.
        const match = findLastMatch(logText, pattern);
        if (!match) continue;

        const hours = parseInt(match[1], 10) || 0;
        const minutes = parseInt(match[2], 10) || 0;
        const seconds = parseInt(match[3], 10) || 0;

        if (hours > 0 || minutes > 0 || seconds > 0) {
            const parts = [];
            if (hours > 0) parts.push(`${hours}小时`);
            if (minutes > 0) parts.push(`${minutes}分钟`);
            // Seconds are only shown when there is nothing coarser to show.
            if (seconds > 0 && hours === 0 && minutes === 0) parts.push(`${seconds}秒`);
            return parts.join('');
        }
    }

    // No explicit time found: report how many epochs are left.
    if (current > 0 && current < total) {
        const remainingEpochs = total - current;
        return `约${remainingEpochs}个epoch`;
    }

    return '计算中...';
}

/**
 * Build the progress summary shown for a training process.
 *
 * Sources, in priority order: nothing for a process that is not running;
 * the structured `process.progress` object when present; otherwise values
 * parsed from the last 2000 characters of `process.logs`.
 *
 * Fixes over the previous version:
 *  - loss / learning-rate extraction called `matchAll` with non-global
 *    regexes, which throws a TypeError; log parsing therefore always failed;
 *  - epoch and step used the first match in the log window (the oldest
 *    entry) instead of the most recent one;
 *  - the bare `[a/b]` pattern appears in both the epoch and the step list,
 *    so the same token was counted as epoch and as step; a step match that
 *    is the very token already used for the epoch is now ignored.
 *
 * @param {object} process - process record; reads `running`, `progress`
 *     and `logs`.
 * @returns {object} `{percentage, current, total, remaining, loss, epoch,
 *     lr}` plus `step`, `currentStep`, `totalSteps` when progress data or
 *     parsable logs are available.
 */
function calculateProgress(process) {
    const defaultProgress = {
        percentage: 0,
        current: 0,
        total: 0,
        remaining: '计算中...',
        loss: null,
        epoch: null,
        lr: null
    };

    // A stopped process has no live progress to report.
    if (!process.running) return defaultProgress;

    // Prefer the structured progress supplied with the process record.
    if (process.progress) {
        return {
            percentage: process.progress.percentage || 0,
            current: process.progress.current_epoch || 0,
            total: process.progress.total_epochs || 0,
            remaining: process.progress.remaining_time || '计算中...',
            loss: process.progress.current_loss || null,
            epoch: process.progress.current_epoch ? `${process.progress.current_epoch}/${process.progress.total_epochs}` : null,
            lr: process.progress.current_lr || null,
            step: process.progress.current_step && process.progress.total_steps ?
                `${process.progress.current_step}/${process.progress.total_steps}` : null,
            currentStep: process.progress.current_step || 0,
            totalSteps: process.progress.total_steps || 0
        };
    }

    if (!process.logs) return defaultProgress;

    // Only the tail of the log is relevant for the current state.
    const logText = process.logs.slice(-2000);

    // Epoch: several supported spellings, first pattern that matches wins.
    const epochPatterns = [
        /epoch\s+(\d+)\s*\/\s*(\d+)/i, // epoch 3/10
        /Epoch\s+(\d+)\s*of\s*(\d+)/i, // Epoch 3 of 10
        /\[(\d+)\/(\d+)\]/i, // [3/10]
        /epoch\s*[::]\s*(\d+)\s*\/\s*(\d+)/i, // epoch: 3/10
        /第\s*(\d+)\s*轮\s*\/\s*共\s*(\d+)\s*轮/i
    ];

    let current = 0;
    let total = 0;
    let percentage = 0;
    let epochMatch = null;
    for (const pattern of epochPatterns) {
        const match = findLastMatch(logText, pattern);
        if (match) {
            epochMatch = match;
            current = parseInt(match[1], 10);
            total = parseInt(match[2], 10);
            percentage = total > 0 ? Math.round((current / total) * 100) : 0;
            break;
        }
    }

    // Step / batch within the epoch.
    const stepPatterns = [
        /step\s+(\d+)\s*\/\s*(\d+)/i, // step 150/1000
        /Step\s+(\d+)\s*of\s*(\d+)/i, // Step 150 of 1000
        /\[(\d+)\/(\d+)\]/i, // [150/1000]
        /step\s*[::]\s*(\d+)\s*\/\s*(\d+)/i, // step: 150/1000
        /第\s*(\d+)\s*步\s*\/\s*共\s*(\d+)\s*步/i,
        /步数\s*(\d+)\s*\/\s*(\d+)/i,
        /batch\s+(\d+)\s*\/\s*(\d+)/i, // batch 150/1000
        /Batch\s+(\d+)\s*of\s*(\d+)/i // Batch 150 of 1000
    ];

    let currentStep = 0;
    let totalSteps = 0;
    let stepInfo = null;
    for (const pattern of stepPatterns) {
        const match = findLastMatch(logText, pattern);
        if (!match) continue;
        // Skip the token that was already consumed as the epoch counter.
        if (epochMatch && match.index === epochMatch.index && match[0] === epochMatch[0]) continue;
        currentStep = parseInt(match[1], 10);
        totalSteps = parseInt(match[2], 10);
        stepInfo = `${currentStep}/${totalSteps}`;
        break;
    }

    // Loss: latest value of the first pattern that yields a plausible number.
    const lossPatterns = [
        /loss[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)/i, // loss: 4.32 or loss = 4.32
        /training_loss[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)/i,
        /train_loss[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)/i,
        /Loss[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)/i,
        /训练损失[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)/i,
        /损失[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)/i,
        /\s+([\d.]+(?:e[+-]?\d+)?)\s*loss/i, // 4.32 loss
        /\s+([\d.]+(?:e[+-]?\d+)?)\s*训练损失/i,
        /(?:loss|损失|training_loss|train_loss)\s*=\s*([\d.]+(?:e[+-]?\d+)?)/i // loss = 4.32
    ];

    let currentLoss = null;
    for (const pattern of lossPatterns) {
        const match = findLastMatch(logText, pattern);
        if (!match) continue;
        const lossValue = parseFloat(match[1]);
        // Values outside (0, 100) are treated as not being a loss.
        if (!isNaN(lossValue) && lossValue > 0 && lossValue < 100) {
            currentLoss = lossValue.toFixed(4);
            break;
        }
    }

    // Learning rate: same strategy as loss.
    const lrPatterns = [
        /lr[\s:=]\s*([\d.e+-]+)/i, // lr: 1e-4
        /learning_rate[\s:=]\s*([\d.e+-]+)/i,
        /LR[\s:=]\s*([\d.e+-]+)/i,
        /学习率[\s:=]\s*([\d.e+-]+)/i
    ];

    let currentLr = null;
    for (const pattern of lrPatterns) {
        const match = findLastMatch(logText, pattern);
        if (!match) continue;
        const lrValue = parseFloat(match[1]);
        // Values outside (0, 1) are treated as not being a learning rate.
        if (!isNaN(lrValue) && lrValue > 0 && lrValue < 1) {
            currentLr = lrValue.toExponential(2);
            break;
        }
    }

    // Without a usable epoch total there is nothing meaningful to report.
    if (total <= 0) return defaultProgress;

    let finalPercentage = percentage;
    if (totalSteps > 0 && currentStep > 0) {
        // Each epoch is worth 1/total of the bar; add the share of the
        // current epoch covered by the step counter.
        // NOTE(review): this treats `current` as completed epochs; if the
        // logs print a 1-based "epoch in progress" the bar runs one epoch
        // ahead — confirm against the trainer's log format.
        const epochPercentage = (current / total) * 100;
        const stepPercentageInEpoch = (currentStep / totalSteps) * 100;
        const stepContribution = stepPercentageInEpoch / total;
        finalPercentage = Math.min(100, Math.max(0, Math.round(epochPercentage + stepContribution)));
    }

    return {
        percentage: finalPercentage,
        current,
        total,
        remaining: calculateRemainingTime(current, total, logText),
        loss: currentLoss,
        epoch: `${current}/${total}`,
        lr: currentLr,
        step: stepInfo,
        currentStep,
        totalSteps
    };
}

暂无训练进程

/**
 * Build the collapsible group element for one training type.
 *
 * The group consists of a clickable header (title plus toggle button) and a
 * content container holding one item per process. Groups start expanded.
 *
 * @param {string} trainType - training-type key used for the header label.
 * @param {Array<object>} processes - processes to render inside the group.
 * @returns {Element} the assembled group element.
 */
function createTypeGroupWithToggle(trainType, processes) {
    const groupEl = el('div', { class: 'process-type-group' });

    // Header: title on the left, expand/collapse button on the right.
    const headerEl = el('div', { class: 'process-type-header' });
    headerEl.dataset.expanded = 'true';
    const titleEl = el('h3', { class: 'process-type-title', text: getTrainTypeDisplayName(trainType) });
    const toggleBtn = el('button', { class: 'toggle-btn' });
    toggleBtn.innerHTML = '▼';
    toggleBtn.onclick = (event) => {
        // Keep the click from also reaching the header's own handler.
        event.stopPropagation();
        toggleGroup(headerEl);
    };
    headerEl.appendChild(titleEl);
    headerEl.appendChild(toggleBtn);
    headerEl.onclick = () => toggleGroup(headerEl);

    // Body: one entry per process.
    const contentEl = el('div', { class: 'process-type-content' });
    for (const proc of processes) {
        addProcessItemToGroup(contentEl, proc);
    }

    groupEl.appendChild(headerEl);
    groupEl.appendChild(contentEl);
    return groupEl;
}
toggle.innerHTML = '▼'; + } +} + +function getTrainTypeDisplayName(trainType) { + const names = { + pretrain: '预训练 (Pretrain)', + sft: '全参数监督微调 (SFT - Full)', + lora: 'LoRA监督微调 (SFT - Lora)', + dpo: '直接偏好优化 (RL - DPO)', + ppo: 'PPO', + grpo: 'GRPO', + spo: 'SPO', + }; + return names[trainType] || trainType; +} + +export function addProcessItemToGroup(parent, process) { + const item = el('div', { class: 'process-item' }); + let statusClass = 'status-completed'; + if (process.status === '运行中') statusClass = 'status-running'; + else if (process.status === '手动停止') statusClass = 'status-manual-stop'; + else if (process.status === '出错') statusClass = 'status-error'; + item.dataset.processId = process.id; + item.dataset.processStatus = process.status; + item.dataset.trainMonitor = process.train_monitor || 'none'; + item.dataset.swanlabUrl = process.swanlab_url || ''; + const showDelete = !process.running; + const showSwanlab = process.train_monitor !== 'none'; + const swanBtn = showSwanlab ? `` : ''; + + // 计算进度信息 + const progressInfo = calculateProgress(process); + const progressBar = process.running ? ` +
+
+
+
+
+ 进度: ${progressInfo.current}/${progressInfo.total}${progressInfo.step ? ` (${progressInfo.step})` : ''} + 剩余时间: ${progressInfo.remaining} +
+
+ ${progressInfo.loss ? `
Loss:${progressInfo.loss}
` : ''} + ${progressInfo.epoch ? `
Epoch:${progressInfo.epoch}
` : ''} + ${progressInfo.step ? `
Step:${progressInfo.step}
` : ''} + ${progressInfo.lr ? `
LR:${progressInfo.lr}
` : ''} +
+
+ ` : ''; + + item.innerHTML = ` +
+
${process.start_time}
+
${process.status}
+
+ ${progressBar} +
+ + + ${swanBtn} + ${process.running ? `` : ''} + ${showDelete ? `` : ''} +
+ + `; + parent.appendChild(item); + bindItemButtons(item, process); +} + +function bindItemButtons(item, process) { + const showBtn = item.querySelector('[data-show]'); + if (showBtn) showBtn.addEventListener('click', () => showLogs(process.id)); + const refreshBtn = item.querySelector('[data-refresh]'); + if (refreshBtn) refreshBtn.addEventListener('click', () => refreshLog(process.id)); + const swanBtn = item.querySelector('[data-swan]'); + if (swanBtn) swanBtn.addEventListener('click', () => checkAndOpenSwanlab(process.id)); + const stopBtn = item.querySelector('[data-stop]'); + if (stopBtn) stopBtn.addEventListener('click', () => stopProcess(process.id)); + const delBtn = item.querySelector('[data-del]'); + if (delBtn) delBtn.addEventListener('click', () => deleteProcess(process.id)); +} + +export function updateProcessProgress(item, process) { + // 只更新进度信息,不更新整个项目 + const progressInfo = calculateProgress(process); + + // 更新进度条 + const progressFill = item.querySelector('.progress-fill'); + const progressText = item.querySelector('.progress-info span:first-child'); + const remainingText = item.querySelector('.progress-info span:last-child'); + const metricsContainer = item.querySelector('.progress-metrics'); + + if (progressFill) { + progressFill.style.width = `${progressInfo.percentage}%`; + } + + if (progressText) { + const stepText = progressInfo.step ? 
` (${progressInfo.step})` : ''; + progressText.textContent = `进度: ${progressInfo.current}/${progressInfo.total}${stepText}`; + } + + if (remainingText) { + remainingText.textContent = `剩余时间: ${progressInfo.remaining}`; + } + + if (metricsContainer) { + // 更新指标 - 只更新有变化的值来减少DOM操作 + const lossItem = metricsContainer.querySelector('.metric-item:nth-child(1) .metric-value'); + const epochItem = metricsContainer.querySelector('.metric-item:nth-child(2) .metric-value'); + const stepItem = metricsContainer.querySelector('.metric-item:nth-child(3) .metric-value'); + const lrItem = metricsContainer.querySelector('.metric-item:nth-child(4) .metric-value'); + + if (progressInfo.loss && lossItem) { + lossItem.textContent = progressInfo.loss; + } + if (progressInfo.epoch && epochItem) { + epochItem.textContent = progressInfo.epoch; + } + if (progressInfo.step && stepItem) { + stepItem.textContent = progressInfo.step; + } + if (progressInfo.lr && lrItem) { + lrItem.textContent = progressInfo.lr; + } + } +} + +export function updateProcessItem(item, process) { + item.dataset.processStatus = process.status; + item.dataset.trainMonitor = process.train_monitor || 'none'; + if (process.swanlab_url) item.dataset.swanlabUrl = process.swanlab_url; + const statusEl = item.querySelector('.process-status'); + if (statusEl) { + statusEl.classList.remove('status-running', 'status-manual-stop', 'status-error', 'status-completed'); + let cls = 'status-completed'; + if (process.status === '运行中') cls = 'status-running'; + else if (process.status === '手动停止') cls = 'status-manual-stop'; + else if (process.status === '出错') cls = 'status-error'; + statusEl.classList.add(cls); + statusEl.textContent = process.status; + } + const btnContainer = item.querySelector('div:nth-child(2)'); + const existingSwan = item.querySelector('.btn-swanlab'); + const showSwan = process.train_monitor !== 'none'; + if (showSwan && !existingSwan && btnContainer) { + const b = el('button', { class: 'btn-swanlab' }); + 
b.textContent = 'SwanLab'; + b.onclick = () => checkAndOpenSwanlab(process.id); + const stop = btnContainer.querySelector('.btn-stop'); + if (stop) btnContainer.insertBefore(b, stop); + else btnContainer.appendChild(b); + } else if (!showSwan && existingSwan) existingSwan.remove(); + const stopBtn = item.querySelector('.btn-stop'); + if (stopBtn) { + if (!process.running) stopBtn.remove(); + } else if (process.running && btnContainer) { + const n = el('button', { class: 'btn-stop' }); + n.textContent = '停止训练'; + n.onclick = () => stopProcess(process.id); + btnContainer.appendChild(n); + } + const delBtn = item.querySelector('.btn-delete'); + if (!process.running) { + if (!delBtn) { + const c = item.querySelector('div:last-child'); + if (c) { + const d = el('button', { class: 'btn-delete' }); + d.textContent = '删除'; + d.onclick = () => deleteProcess(process.id); + c.appendChild(d); + } + } + } else if (delBtn) delBtn.remove(); + if (!process.running) clearLogTimerFor(process.id); +} + +export function deleteProcess(processId) { + showConfirmDialog('确定要删除这个训练进程吗?此操作不可恢复。', () => { + apiDelete(processId) + .then(() => { + const item = document.querySelector(`[data-process-id="${processId}"]`); + if (item && item.parentNode) { + item.style.transition = 'opacity 0.3s, transform 0.3s'; + item.style.opacity = '0'; + item.style.transform = 'translateX(-20px)'; + setTimeout(() => { + const content = item.closest('.process-type-content'); + const group = content ? 
content.closest('.process-type-group') : null; + item.parentNode.removeChild(item); + if (content) { + const remain = content.querySelectorAll('.process-item'); + if (remain.length === 0 && group) { + setTimeout(() => { + group.style.transition = 'opacity 0.3s, transform 0.3s'; + group.style.opacity = '0'; + group.style.transform = 'translateY(-10px)'; + setTimeout(() => { + if (group.parentNode) group.parentNode.removeChild(group); + const left = document.querySelectorAll('.process-item'); + if (left.length === 0) { + const list = document.getElementById('process-list'); + list.innerHTML = '

暂无训练进程

'; + } + }, 300); + }, 100); + } else { + const header = content.previousElementSibling; + if (header && header.dataset.expanded === 'true') content.style.maxHeight = content.scrollHeight + 'px'; + const left = document.querySelectorAll('.process-item'); + if (left.length === 0) { + const list = document.getElementById('process-list'); + list.innerHTML = '

暂无训练进程

'; + } + } + } + }, 300); + } + clearLogTimerFor(processId); + showNotification('训练进程已删除', 'success'); + }) + .catch(() => { + showNotification('删除进程失败,请刷新页面重试', 'error'); + }); + }); +} + +export function stopProcess(processId) { + showConfirmDialog('确定要停止这个训练进程吗?', () => { + apiStop(processId) + .then(() => { + const item = document.querySelector(`[data-process-id="${processId}"]`); + if (item) { + item.dataset.processStatus = '手动停止'; + const statusEl = item.querySelector('.process-status'); + if (statusEl) { + statusEl.classList.remove('status-running', 'status-error', 'status-completed'); + statusEl.classList.add('status-manual-stop'); + statusEl.textContent = '手动停止'; + } + const stopBtn = item.querySelector('.btn-stop'); + if (stopBtn) stopBtn.remove(); + clearLogTimerFor(processId); + } + showNotification('训练进程已停止', 'info'); + getProcesses() + .then((data) => { + const updated = data.find((p) => p.id === processId); + if (updated && item) updateProcessItem(item, updated); + }) + .catch(() => {}); + }) + .catch(() => { + showNotification('停止进程失败', 'error'); + }); + }, () => { + showNotification('已取消停止操作', 'info'); + }); +} + +export function checkAndOpenSwanlab(processId) { + const item = document.querySelector(`[data-process-id="${processId}"]`); + const monitor = item ? item.dataset.trainMonitor : 'none'; + if (monitor === 'none') { + showNotification('此训练未启用监控功能', 'info'); + return; + } + let url = item ? 
item.dataset.swanlabUrl : ''; + if (!url || url.trim() === '') { + getProcesses() + .then((data) => { + const p = data.find((x) => x.id === processId); + if (p && p.swanlab_url) { + url = p.swanlab_url; + if (item) item.dataset.swanlabUrl = url; + openSwanlab(url); + } else { + showNotification('SwanLab链接尚未生成,请稍后再试', 'info'); + } + }) + .catch(() => { + showNotification('获取SwanLab链接失败,请稍后再试', 'error'); + }); + } else openSwanlab(url); +} + +function openSwanlab(url) { + if (!isValidUrl(url)) { + showNotification('SwanLab链接无效或尚未生成', 'info'); + return; + } + const w = window.open(url, '_blank'); + if (w) showNotification('正在打开SwanLab页面', 'info'); + else showNotification('无法打开新窗口,请检查浏览器设置', 'error'); +} + +function isValidUrl(url) { + try { + new URL(url); + return true; + } catch { + const u = String(url).toLowerCase(); + return u.startsWith('http://') || u.startsWith('https://'); + } +} + diff --git a/trainer_web/static/js/processes/logs.js b/trainer_web/static/js/processes/logs.js new file mode 100644 index 0000000..3c63c0c --- /dev/null +++ b/trainer_web/static/js/processes/logs.js @@ -0,0 +1,73 @@ +import { getLogs } from '../services/apiClient.js'; +import { setHidden } from '../utils/dom.js'; + +const logTimers = new Map(); + +export function showLogs(processId) { + const container = document.getElementById(`logs-${processId}`); + if (!container) return; + const wasHidden = container.classList.contains('hidden'); + setHidden(container, false); + if (wasHidden) { + loadLogContent(processId, container); + resetTimer(processId, container); + } else { + setHidden(container, true); + clearTimer(processId); + } +} + +export function refreshLog(processId) { + const container = document.getElementById(`logs-${processId}`); + if (!container || container.classList.contains('hidden')) return; + loadLogContent(processId, container); + resetTimer(processId, container); +} + +export function clearLogTimerFor(processId) { + clearTimer(processId); +} + +export function 
isLogTimerActive(processId) { + return logTimers.has(processId); +} + +function resetTimer(processId, container) { + clearTimer(processId); + const item = document.querySelector(`[data-process-id="${processId}"]`); + const running = item && item.dataset.processStatus === '运行中'; + if (!running) return; + const id = setInterval(() => { + if (container.classList.contains('hidden')) { + clearTimer(processId); + return; + } + const current = document.querySelector(`[data-process-id="${processId}"]`); + const stillRunning = current && current.dataset.processStatus === '运行中'; + if (stillRunning) loadLogContent(processId, container); + else clearTimer(processId); + }, 1000); + logTimers.set(processId, id); +} + +function clearTimer(processId) { + const id = logTimers.get(processId); + if (id) { + clearInterval(id); + logTimers.delete(processId); + } +} + +function loadLogContent(processId, container) { + const old = container.textContent; + const stickBottom = container.scrollHeight - container.scrollTop <= container.clientHeight + 10; + return getLogs(processId) + .then((logs) => { + container.textContent = logs; + if (stickBottom || old === container.textContent) container.scrollTop = container.scrollHeight; + }) + .catch((err) => { + if (!container.textContent.includes('加载失败')) container.textContent = `加载日志失败: ${err.message}`; + }); +} + diff --git a/trainer_web/static/js/services/apiClient.js b/trainer_web/static/js/services/apiClient.js new file mode 100644 index 0000000..9652661 --- /dev/null +++ b/trainer_web/static/js/services/apiClient.js @@ -0,0 +1,75 @@ +const defaultTimeout = 10000; + +export function fetchWithTimeoutAndRetry(url, options = {}, timeout = defaultTimeout, retries = 3) { + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), timeout); + const fetchOptions = { + ...options, + headers: { + ...options.headers, + 'Cache-Control': 'no-cache, no-store, must-revalidate', + Pragma: 'no-cache', + Expires: '0', 
+ }, + signal: controller.signal, + }; + + return fetch(url, fetchOptions) + .then((response) => { + clearTimeout(timeoutId); + if (!response.ok) throw new Error(`HTTP ${response.status}`); + return response; + }) + .catch((error) => { + clearTimeout(timeoutId); + if (error.name === 'AbortError') throw new Error('请求超时'); + if (retries > 0) { + return new Promise((resolve) => { + setTimeout(() => { + resolve(fetchWithTimeoutAndRetry(url, options, timeout, retries - 1)); + }, timeout / 2); + }); + } + throw error; + }); +} + +export function getProcesses() { + return fetchWithTimeoutAndRetry('/processes').then((r) => r.json()); +} + +export function getLogs(processId) { + return fetchWithTimeoutAndRetry(`/logs/${processId}`).then((r) => r.text()); +} + +export function startTrain(payload) { + return fetchWithTimeoutAndRetry('/train', { + method: 'POST', + headers: { 'Content-Type': 'application/json', 'Cache-Control': 'no-cache' }, + body: JSON.stringify(payload), + }).then((r) => r.json()); +} + +export function stopProcess(processId) { + return fetchWithTimeoutAndRetry(`/stop/${processId}`, { method: 'POST' }).then((r) => r.json().catch(() => ({}))); +} + +export function deleteProcess(processId) { + return fetchWithTimeoutAndRetry(`/delete/${processId}`, { method: 'POST' }).then((r) => r.json().catch(() => ({}))); +} + +export function getLogFiles() { + return fetchWithTimeoutAndRetry('/logfiles').then((r) => r.json()); +} + +export function getLogFileContent(filename) { + return fetchWithTimeoutAndRetry(`/logfile-content/${encodeURIComponent(filename)}`).then((r) => r.text()); +} + +export function deleteLogFile(filename) { + return fetchWithTimeoutAndRetry(`/delete-logfile/${encodeURIComponent(filename)}`, { + method: 'DELETE', + headers: { 'Cache-Control': 'no-cache' }, + }).then((r) => r.json()); +} + diff --git a/trainer_web/static/js/services/authClient.js b/trainer_web/static/js/services/authClient.js new file mode 100644 index 0000000..18470c5 --- 
/dev/null +++ b/trainer_web/static/js/services/authClient.js @@ -0,0 +1,29 @@ +const KEY = 'minimind_api_key'; + +export function getApiKey() { + try { + return localStorage.getItem(KEY) || ''; + } catch (_) { + return ''; + } +} + +export function setApiKey(k) { + try { + localStorage.setItem(KEY, k || ''); + } catch (_) {} +} + +export function registerClient(payload) { + return fetch('/api/register', { + method: 'POST', + headers: { 'Content-Type': 'application/json', 'Cache-Control': 'no-cache' }, + body: JSON.stringify(payload || {}), + }).then((r) => { + if (!r.ok) throw new Error('register_failed'); + return r.json(); + }).then((res) => { + if (res && res.api_key) setApiKey(res.api_key); + return res; + }); +} \ No newline at end of file diff --git a/trainer_web/static/js/train/form.js b/trainer_web/static/js/train/form.js new file mode 100644 index 0000000..6e9f893 --- /dev/null +++ b/trainer_web/static/js/train/form.js @@ -0,0 +1,128 @@ +import { startTrain } from '../services/apiClient.js'; +import { showNotification } from '../ui/notify.js'; + +export function initTrainForm() { + const typeSel = document.getElementById('train_type'); + if (typeSel) { + typeSel.addEventListener('change', onTrainTypeChange); + typeSel.dispatchEvent(new Event('change')); + } + initGpuSelectors(); + const form = document.getElementById('train-form'); + if (form) form.addEventListener('submit', onSubmit); +} + +function onTrainTypeChange() { + const v = this.value; + const pretrainSft = document.querySelectorAll('.pretrain-sft'); + const fromWeightFields = document.querySelectorAll('.from-weight'); + const loraFields = document.querySelectorAll('.lora'); + const dpoFields = document.querySelectorAll('.dpo'); + const dpoCard = document.querySelector('.parameter-card.dpo'); + const ppoFields = document.querySelectorAll('.ppo'); + const ppoCard = document.querySelector('.parameter-card.ppo'); + const grpoFields = document.querySelectorAll('.grpo'); + const grpoCard = 
document.querySelector('.parameter-card.grpo'); + const spoFields = document.querySelectorAll('.spo'); + const spoCard = document.querySelector('.parameter-card.spo'); + pretrainSft.forEach((f) => (f.style.display = v === 'pretrain' || v === 'sft' || v === 'dpo' || v === 'ppo' || v === 'grpo' || v === 'spo' ? 'block' : 'none')); + fromWeightFields.forEach((f) => (f.style.display = v !== 'ppo' && v !== 'grpo' && v !== 'spo' ? 'block' : 'none')); + loraFields.forEach((f) => (f.style.display = v === 'lora' ? 'block' : 'none')); + dpoFields.forEach((f) => (f.style.display = v === 'dpo' ? 'block' : 'none')); + ppoFields.forEach((f) => (f.style.display = v === 'ppo' ? 'block' : 'none')); + if (dpoCard) dpoCard.style.display = v === 'dpo' ? 'block' : 'none'; + if (ppoCard) ppoCard.style.display = v === 'ppo' ? 'block' : 'none'; + grpoFields.forEach((f) => (f.style.display = v === 'grpo' ? 'block' : 'none')); + spoFields.forEach((f) => (f.style.display = v === 'spo' ? 'block' : 'none')); + if (grpoCard) grpoCard.style.display = v === 'grpo' ? 'block' : 'none'; + if (spoCard) spoCard.style.display = v === 'spo' ? 
'block' : 'none'; + if (v === 'pretrain') setDefaults({ save_dir: '../out', save_weight: 'pretrain', epochs: '1', batch_size: '32', learning_rate: '5e-4', data_path: '../dataset/pretrain_hq.jsonl', from_weight: 'none', log_interval: '100', save_interval: '100', hidden_size: '512', num_hidden_layers: '8', max_seq_len: '512', use_moe: '0' }); + else if (v === 'sft') setDefaults({ save_dir: '../out', save_weight: 'full_sft', epochs: '2', batch_size: '16', learning_rate: '5e-7', data_path: '../dataset/sft_mini_512.jsonl', from_weight: 'pretrain', log_interval: '100', save_interval: '100', hidden_size: '512', num_hidden_layers: '8', max_seq_len: '512', use_moe: '0' }); + else if (v === 'lora') setDefaults({ save_dir: '../out/lora', lora_name: 'lora_identity', epochs: '50', batch_size: '32', learning_rate: '1e-4', data_path: '../dataset/lora_identity.jsonl', from_weight: 'full_sft', log_interval: '10', save_interval: '1', hidden_size: '512', num_hidden_layers: '8', max_seq_len: '512', use_moe: '0' }); + else if (v === 'dpo') setDefaults({ save_dir: '../out', save_weight: 'dpo', epochs: '1', batch_size: '4', learning_rate: '4e-8', data_path: '../dataset/dpo.jsonl', from_weight: 'full_sft', log_interval: '100', save_interval: '100', beta: '0.1', hidden_size: '512', num_hidden_layers: '8', max_seq_len: '1024', use_moe: '0' }); + else if (v === 'ppo') setDefaults({ save_dir: '../out', save_weight: 'ppo_actor', epochs: '1', batch_size: '2', learning_rate: '8e-8', data_path: '../dataset/rlaif-mini.jsonl', log_interval: '1', save_interval: '10', clip_epsilon: '0.1', vf_coef: '0.5', kl_coef: '0.02', reasoning: '1', update_old_actor_freq: '4', reward_model_path: '../../internlm2-1_8b-reward', hidden_size: '512', num_hidden_layers: '8', max_seq_len: '66', use_moe: '0' }); + else if (v === 'grpo') setDefaults({ save_dir: '../out', save_weight: 'grpo', epochs: '1', batch_size: '2', learning_rate: '8e-8', data_path: '../dataset/rlaif-mini.jsonl', log_interval: '1', save_interval: 
'10', beta: '0.02', num_generations: '8', reasoning: '1', reward_model_path: '../../internlm2-1_8b-reward', hidden_size: '512', num_hidden_layers: '8', max_seq_len: '66', use_moe: '0' }); + else if (v === 'spo') setDefaults({ save_dir: '../out', save_weight: 'spo', epochs: '1', batch_size: '2', learning_rate: '1e-7', data_path: '../dataset/rlaif-mini.jsonl', log_interval: '1', save_interval: '10', beta: '0.02', reasoning: '1', reward_model_path: '../../internlm2-1_8b-reward', hidden_size: '512', num_hidden_layers: '8', max_seq_len: '66', use_moe: '0' }); +} + +function setDefaults(map) { + Object.entries(map).forEach(([name, val]) => { + const nodes = document.querySelectorAll(`[name="${name}"]`); + nodes.forEach((node) => { + const card = node.closest('.parameter-card'); + const visible = !card || card.style.display !== 'none'; + if (visible) node.value = val; + }); + }); +} + +function initGpuSelectors() { + const hasGpu = window.hasGpu === true; + const gpuCount = Number(window.gpuCount || 0); + const modeSel = document.getElementById('training_mode'); + const single = document.getElementById('single-gpu-selection'); + const multi = document.getElementById('multi-gpu-selection'); + if (!modeSel) return; + function updateVisibility() { + const mode = modeSel.value; + if (single) single.style.display = mode === 'single_gpu' ? 'block' : 'none'; + if (multi) multi.style.display = mode === 'multi_gpu' ? 'block' : 'none'; + } + if (!hasGpu) { + modeSel.value = 'cpu'; + if (single) single.style.display = 'none'; + if (multi) multi.style.display = 'none'; + } else { + const gpuNumInput = document.getElementById('gpu_num'); + if (gpuNumInput && gpuCount > 0) gpuNumInput.value = gpuCount; + } + updateVisibility(); + modeSel.addEventListener('change', updateVisibility); +} + +function onSubmit(e) { + e.preventDefault(); + const form = e.currentTarget; + const data = {}; + const trainingModeSel = form.querySelector('#training_mode'); + const trainingMode = trainingModeSel ? 
trainingModeSel.value : 'cpu'; + const inputs = form.querySelectorAll('input, select, textarea'); + inputs.forEach((el) => { + const name = el.name; + if (!name || name === 'training_mode') return; + const card = el.closest('.parameter-card'); + const visible = !card || card.style.display !== 'none'; + if (!visible) return; + let value = el.value; + if (el.type === 'checkbox') { + if (!el.checked) return; + } + if (name === 'gpu_num') { + const multi = document.getElementById('multi-gpu-selection'); + if (!(multi && multi.style.display !== 'none')) return; + } + if (name === 'device') { + if (trainingMode === 'single_gpu') value = `cuda:${value}`; + else if (trainingMode === 'cpu') value = 'cpu'; + else return; + } + data[name] = value; + }); + showNotification('正在启动训练...', 'info'); + setTimeout(() => { + startTrain(data) + .then((result) => { + if (result.success) { + showNotification('训练已开始!', 'success'); + setTimeout(() => { + const processTab = document.querySelector('.tab[onclick*="processes"]'); + if (processTab) processTab.click(); + }, 1000); + } else showNotification('训练启动失败:' + result.error, 'error'); + }) + .catch(() => { + showNotification('启动训练中,请耐心等待...', 'info'); + }); + }, 1000); +} + diff --git a/trainer_web/static/js/ui/dialog.js b/trainer_web/static/js/ui/dialog.js new file mode 100644 index 0000000..f9ebca5 --- /dev/null +++ b/trainer_web/static/js/ui/dialog.js @@ -0,0 +1,51 @@ +export function showConfirmDialog(message, onConfirm, onCancel = null) { + const existing = document.querySelector('.custom-dialog'); + if (existing && existing.parentNode && existing.parentNode.classList.contains('dialog-overlay')) { + document.body.removeChild(existing.parentNode); + } + const overlay = document.createElement('div'); + overlay.className = 'dialog-overlay'; + const container = document.createElement('div'); + container.className = 'custom-dialog'; + container.innerHTML = ` +
+
${message}
+
+ + +
+
+ `; + overlay.appendChild(container); + document.body.appendChild(overlay); + setTimeout(() => { + overlay.classList.add('show'); + container.classList.add('show'); + }, 10); + const confirmBtn = container.querySelector('.dialog-confirm'); + confirmBtn.addEventListener('click', () => { + if (onConfirm) onConfirm(); + closeDialog(overlay); + }); + const cancelBtn = container.querySelector('.dialog-cancel'); + cancelBtn.addEventListener('click', () => { + if (onCancel) onCancel(); + closeDialog(overlay); + }); + overlay.addEventListener('click', (e) => { + if (e.target === overlay) { + if (onCancel) onCancel(); + closeDialog(overlay); + } + }); +} + +export function closeDialog(overlay) { + overlay.classList.remove('show'); + const container = overlay.querySelector('.custom-dialog'); + if (container) container.classList.remove('show'); + setTimeout(() => { + if (overlay.parentNode) document.body.removeChild(overlay); + }, 300); +} + diff --git a/trainer_web/static/js/ui/notify.js b/trainer_web/static/js/ui/notify.js new file mode 100644 index 0000000..33ad846 --- /dev/null +++ b/trainer_web/static/js/ui/notify.js @@ -0,0 +1,16 @@ +export function showNotification(message, type = 'success') { + const n = document.createElement('div'); + n.className = `notification notification-${type}`; + n.textContent = message; + document.body.appendChild(n); + setTimeout(() => { + n.classList.add('show'); + }, 10); + setTimeout(() => { + n.classList.remove('show'); + setTimeout(() => { + if (n.parentNode) document.body.removeChild(n); + }, 300); + }, 3000); +} + diff --git a/trainer_web/static/js/ui/tabs.js b/trainer_web/static/js/ui/tabs.js new file mode 100644 index 0000000..b4746d2 --- /dev/null +++ b/trainer_web/static/js/ui/tabs.js @@ -0,0 +1,15 @@ +import { qsa } from '../utils/dom.js'; + +export function openTab(evt, tabName, hooks = {}) { + const contents = qsa('.tab-content'); + contents.forEach((c) => c.classList.add('hidden')); + const tabs = qsa('.tab'); + 
tabs.forEach((t) => t.classList.remove('active')); + const target = document.getElementById(tabName); + if (target) target.classList.remove('hidden'); + if (evt && evt.currentTarget) evt.currentTarget.classList.add('active'); + if (tabName !== 'processes' && hooks.onLeaveProcesses) hooks.onLeaveProcesses(); + if (tabName === 'processes' && hooks.onEnterProcesses) hooks.onEnterProcesses(); + if (tabName === 'logfiles' && hooks.onEnterLogfiles) hooks.onEnterLogfiles(); +} + diff --git a/trainer_web/static/js/utils/dom.js b/trainer_web/static/js/utils/dom.js new file mode 100644 index 0000000..f7e0380 --- /dev/null +++ b/trainer_web/static/js/utils/dom.js @@ -0,0 +1,34 @@ +export function qs(selector, scope = document) { + return scope.querySelector(selector); +} + +export function qsa(selector, scope = document) { + return Array.from(scope.querySelectorAll(selector)); +} + +export function el(tag, attrs = {}) { + const node = document.createElement(tag); + for (const [k, v] of Object.entries(attrs)) { + if (k === 'class') node.className = v; + else if (k === 'text') node.textContent = v; + else node.setAttribute(k, v); + } + return node; +} + +export function setHidden(node, hidden) { + if (!node) return; + if (hidden) node.classList.add('hidden'); + else node.classList.remove('hidden'); +} + +export function setText(node, text) { + if (!node) return; + node.textContent = text; +} + +export function clearChildren(node) { + if (!node) return; + while (node.firstChild) node.removeChild(node.firstChild); +} + diff --git a/trainer_web/stop_web_ui.sh b/trainer_web/stop_web_ui.sh new file mode 100755 index 0000000..9ed66ae --- /dev/null +++ b/trainer_web/stop_web_ui.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +cd "$SCRIPT_DIR" + +if [ -f "train_web_ui.pid" ]; then + pid=$(cat "train_web_ui.pid") + if ps -p "$pid" > /dev/null 2>&1; then + echo "正在停止 Web UI 服务 (PID: $pid)" + kill "$pid" + sleep 2 + # 检查是否成功停止 + if ps -p "$pid" > /dev/null 
2>&1; then + echo "强制停止服务..." + kill -9 "$pid" + fi + echo "正在保存进程信息..." + echo "已保存到 'trainer_web/training_processes.json'" + sleep 1 + echo "服务已停止" + else + echo "服务未运行,但存在PID文件,已删除" + rm "train_web_ui.pid" + fi +else + echo "服务未运行(未找到PID文件)" +fi \ No newline at end of file diff --git a/trainer_web/templates/index.html b/trainer_web/templates/index.html new file mode 100644 index 0000000..301874c --- /dev/null +++ b/trainer_web/templates/index.html @@ -0,0 +1,357 @@ + + + + + + MiniMind Training Lab + + + + + + +
+ +

MiniMind Training Lab

+
+ +
+ + + +
+ +
+
+

选择训练类型并配置参数

+
+ +
+

基础训练参数

+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ +
+ + +
+
+
+
+ + + + + + + + + + + + + + +
+

模型结构参数

+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+

模型保存与恢复

+
+
+ +
+ + +
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+
+
+ + +
+

其他设置

+
+
+ + +
+
+ + +
+ +
+ + +
+
+
+ + + +
+ +
+
+
+
+ + + + + + + + + + + diff --git a/trainer_web/train_web_ui.py b/trainer_web/train_web_ui.py new file mode 100644 index 0000000..d145a84 --- /dev/null +++ b/trainer_web/train_web_ui.py @@ -0,0 +1,1091 @@ +import os +import sys +import subprocess +import threading +import json +import socket +import atexit +import signal +import re +from flask import Flask, render_template, request, jsonify, redirect, url_for +from flask import g +import time +import psutil +import glob +import pathlib + +# 尝试导入torch来检测GPU +try: + import torch + HAS_TORCH = True + # 检测可用的GPU数量和设备信息 + if torch.cuda.is_available(): + GPU_COUNT = torch.cuda.device_count() + # 获取GPU设备名称 + GPU_NAMES = [torch.cuda.get_device_name(i) for i in range(GPU_COUNT)] + else: + GPU_COUNT = 0 + GPU_NAMES = [] +except ImportError: + HAS_TORCH = False + GPU_COUNT = 0 + GPU_NAMES = [] + +def calculate_training_progress(process_id, process_info): + """ + 计算训练进度信息 + 从日志文件中提取训练进度、loss、epoch等信息 + """ + progress = { + 'percentage': 0, + 'current_epoch': 0, + 'total_epochs': 0, + 'current_step': 0, + 'total_steps': 0, + 'remaining_time': '计算中...', + 'current_loss': None, + 'current_lr': None + } + + # 如果进程不在运行且没有日志文件,返回空进度 + if not process_info.get('running', False): + # 检查是否有日志文件,如果有则继续解析 + script_dir = os.path.dirname(os.path.abspath(__file__)) + log_dir = os.path.join(script_dir, '../logfile') + log_dir = os.path.abspath(log_dir) + + log_file_exists = False + if os.path.exists(log_dir): + for filename in os.listdir(log_dir): + if filename.endswith(f'{process_id}.log'): + log_file_exists = True + break + + # 如果没有日志文件且进程不在运行,返回空进度 + if not log_file_exists: + return progress + + try: + # 获取日志文件路径 + script_dir = os.path.dirname(os.path.abspath(__file__)) + log_dir = os.path.join(script_dir, '../logfile') + log_dir = os.path.abspath(log_dir) + + log_file = None + if os.path.exists(log_dir): + for filename in os.listdir(log_dir): + if filename.endswith(f'{process_id}.log'): + log_file = os.path.join(log_dir, filename) + break 
+ + if not log_file or not os.path.exists(log_file): + return progress + + # 读取日志文件的最后1000行 + def read_last_lines(file_path, n=1000): + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + # 使用更高效的方法读取最后n行 + lines = [] + for line in f: + lines.append(line.strip()) + if len(lines) > n: + lines.pop(0) + return lines + except Exception: + return [] + + lines = read_last_lines(log_file, 1000) + + # 从日志中提取进度信息 + current_epoch = 0 + total_epochs = 0 + current_loss = None + current_lr = None + + for line in reversed(lines): # 从最新日志开始 + line = line.strip() + if not line: + continue + + # 提取epoch信息 - 支持多种格式 + if not total_epochs: + # 格式: epoch 3/10, Epoch 3 of 10, [3/10], 第3轮/共10轮, Epoch:[1/1] + epoch_patterns = [ + r'Epoch:\[(\d+)/(\d+)\]', # Epoch:[1/1] - 新格式 + r'epoch\s+(\d+)\s*/\s*(\d+)', + r'Epoch\s+(\d+)\s*of\s*(\d+)', + r'\[(\d+)/(\d+)\]', + r'epoch\s*[::]\s*(\d+)\s*/\s*(\d+)', + r'第\s*(\d+)\s*轮\s*/\s*共\s*(\d+)\s*轮' + ] + + for pattern in epoch_patterns: + match = re.search(pattern, line, re.IGNORECASE) + if match: + if 'Epoch:\[' in pattern: + current_epoch = int(match.group(1)) + total_epochs = int(match.group(2)) + else: + current_epoch = int(match.group(1)) + total_epochs = int(match.group(2)) + break + + # 提取step信息 - 支持多种格式 + # 格式: (74/44160), step 150/1000, Step 150 of 1000, [150/1000], step: 150/1000 + step_patterns = [ + r'\((\d+)/(\d+)\)', # (74/44160) - 新格式 + r'step\s+(\d+)\s*/\s*(\d+)', + r'Step\s+(\d+)\s*of\s*(\d+)', + r'\[(\d+)/(\d+)\]', + r'step\s*[::]\s*(\d+)\s*/\s*(\d+)', + r'第\s*(\d+)\s*步\s*/\s*共\s*(\d+)\s*步', + r'步数\s*(\d+)\s*/\s*(\d+)', + r'batch\s+(\d+)\s*/\s*(\d+)', # 也支持batch格式 + r'Batch\s+(\d+)\s*of\s*(\d+)' + ] + + for pattern in step_patterns: + match = re.search(pattern, line, re.IGNORECASE) + if match: + progress['current_step'] = int(match.group(1)) + progress['total_steps'] = int(match.group(2)) + break + + # 提取loss信息 - 支持多种格式 + if not current_loss: + # 格式: loss:8.896761, loss: 4.32, training_loss: 4.32, train_loss: 
4.32, Loss: 4.32, 训练损失: 4.32 + loss_patterns = [ + r'loss:([\d.]+(?:e[+-]?\d+)?)', # loss:8.896761 - 新格式 + r'loss[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)', # loss: 4.32 + r'training_loss[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)', # training_loss: 4.32 + r'train_loss[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)', # train_loss: 4.32 + r'Loss[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)', # Loss: 4.32 + r'训练损失[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)', # 训练损失: 4.32 + r'损失[\s:=]\s*([\d.]+(?:e[+-]?\d+)?)', # 损失: 4.32 + r'\s+([\d.]+(?:e[+-]?\d+)?)\s*loss', # 4.32 loss + r'\s+([\d.]+(?:e[+-]?\d+)?)\s*训练损失', # 4.32 训练损失 + r'(?:loss|损失|training_loss|train_loss)\s*=\s*([\d.]+(?:e[+-]?\d+)?)' # loss = 4.32 + ] + + for pattern in loss_patterns: + matches = re.findall(pattern, line, re.IGNORECASE) + if matches: + # 取最后一个匹配的loss值 + loss_value = float(matches[-1]) + if 0 < loss_value < 100: # 合理的loss范围 + current_loss = loss_value + break + + # 提取学习率信息 - 支持多种格式 + if not current_lr: + # 格式: lr:0.000549999999, lr: 1e-4, learning_rate: 1e-4, LR: 1e-4, 学习率: 1e-4 + lr_patterns = [ + r'lr:([\d.e+-]+)', # lr:0.000549999999 - 新格式 + r'lr[\s:=]\s*([\d.e+-]+)', + r'learning_rate[\s:=]\s*([\d.e+-]+)', + r'LR[\s:=]\s*([\d.e+-]+)', + r'学习率[\s:=]\s*([\d.e+-]+)' + ] + + for pattern in lr_patterns: + matches = re.findall(pattern, line, re.IGNORECASE) + if matches: + # 取最后一个匹配的lr值 + lr_value = float(matches[-1]) + if 0 < lr_value < 1: # 合理的lr范围 + current_lr = f"{lr_value:.2e}" + break + + # 如果已经收集到足够信息,提前退出 + if total_epochs and current_loss and current_lr: + break + + # 计算进度百分比 - 支持epoch和step双重进度 + percentage = 0 + if total_epochs > 0: + # 基础epoch进度 + epoch_percentage = (current_epoch / total_epochs) * 100 + + # 如果有step信息,在当前epoch内计算step进度 + if progress['total_steps'] > 0 and progress['current_step'] > 0: + # 计算当前epoch内的step进度 + step_percentage_in_epoch = (progress['current_step'] / progress['total_steps']) * 100 + # 将step进度加到epoch进度上(每个epoch占总进度的1/total_epochs) + step_contribution = step_percentage_in_epoch / total_epochs + percentage = min(100, max(0, 
int(epoch_percentage + step_contribution))) + else: + # 只有epoch信息的传统计算方式 + percentage = min(100, max(0, int(epoch_percentage))) + + # 更新进度字典 + progress['percentage'] = percentage + progress['current_epoch'] = current_epoch + progress['total_epochs'] = total_epochs + progress['current_loss'] = current_loss + progress['current_lr'] = current_lr + + # 估算剩余时间(增强计算) + remaining_time = '计算中...' + if current_epoch > 0 and total_epochs > current_epoch: + # 从日志中提取时间信息 + for line in reversed(lines): + # 格式: remaining: 1:30:45, ETA: 1:30:45, 预计剩余: 1小时30分钟, epoch_Time:332.0min: + time_patterns = [ + r'epoch_Time:([\d.]+)min:', # epoch_Time:332.0min: - 新格式 + r'remaining[\s:=]\s*(\d+):(\d+):(\d+)', # remaining: 1:30:45 + r'ETA[\s:=]\s*(\d+):(\d+):(\d+)', # ETA: 1:30:45 + r'预计剩余[\s:=]\s*(\d+)[\s小时]*[\s:]?(\d+)?[\s分钟]*', # 预计剩余: 1小时30分钟 + r'剩余时间[\s:=]\s*(\d+)[\s小时]*[\s:]?(\d+)?[\s分钟]*', # 剩余时间: 1小时30分钟 + r'time left[\s:=]\s*(\d+)[\s:]?(\d+)?[\s:]?(\d+)?', # time left: 1:30:45 + r'还需[\s:=]\s*(\d+)[\s小时]*[\s:]?(\d+)?[\s分钟]*' # 还需: 1小时30分钟 + ] + + for pattern in time_patterns: + match = re.search(pattern, line, re.IGNORECASE) + if match: + # 处理epoch_Time格式 + if 'epoch_Time:' in pattern: + minutes = float(match.group(1)) + if minutes > 0: + if minutes >= 60: + hours = int(minutes // 60) + remaining_minutes = int(minutes % 60) + if hours > 0: + remaining_time = f"{hours}小时{remaining_minutes}分钟" + else: + remaining_time = f"{remaining_minutes}分钟" + else: + remaining_time = f"{int(minutes)}分钟" + break + else: + groups = match.groups() + if len(groups) >= 3 and all(groups[:3]): + # 小时:分钟:秒格式 + hours = int(groups[0]) + minutes = int(groups[1]) + seconds = int(groups[2]) + if hours > 0 or minutes > 0 or seconds > 0: + parts = [] + if hours > 0: parts.append(f"{hours}小时") + if minutes > 0: parts.append(f"{minutes}分钟") + if seconds > 0 and hours == 0 and minutes == 0: + parts.append(f"{seconds}秒") + remaining_time = ''.join(parts) + break + elif len(groups) >= 2: + # 小时和分钟格式 + hours = 
def get_supported_training_methods():
    """Return a mapping of training-method name to an availability flag.

    The single-process methods (pretrain, SFT, LoRA, DPO) are always reported
    as available.  ``multi_gpu`` is only truthy when PyTorch was importable
    (``HAS_TORCH``) and more than one GPU was detected (``GPU_COUNT``).
    """
    # Methods that do not depend on the detected hardware.
    always_available = ('pretrain', 'sft', 'lora', 'dpo')
    supported = dict.fromkeys(always_available, True)
    # Multi-GPU training needs both PyTorch and at least two GPUs.
    supported['multi_gpu'] = HAS_TORCH and GPU_COUNT > 1
    return supported
def start_training_process(train_type, params, client_id=None):
    """Launch a training subprocess and register it in ``training_processes``.

    Args:
        train_type: Training method name; forwarded to ``build_command``.
        params: Mapping of request parameters.  ``gpu_num`` selects the GPU
            count and ``train_monitor`` is stored with the process record.
        client_id: Optional identifier stored with the process record.

    Returns:
        The generated process id (timestamp based), or ``None`` when
        ``build_command`` returns ``None`` for the request.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Timestamp-based id.  Two launches within the same second would otherwise
    # overwrite each other's registry entry and log file, so add a counter.
    base_id = time.strftime('%Y%m%d_%H%M%S')
    process_id = base_id
    suffix = 1
    while process_id in training_processes:
        process_id = f"{base_id}_{suffix}"
        suffix += 1

    log_dir = os.path.abspath(os.path.join(script_dir, '../logfile'))
    log_file = os.path.join(log_dir, f"train_{train_type}_{process_id}.log")
    os.makedirs(log_dir, exist_ok=True)

    # A missing, empty or non-numeric gpu_num falls back to 0 instead of
    # raising out of the request handler.
    try:
        gpu_num = int(params.get('gpu_num') or 0)
    except (TypeError, ValueError):
        gpu_num = 0
    # torchrun is only used when more than one GPU is requested.
    use_torchrun = HAS_TORCH and GPU_COUNT > 0 and gpu_num > 1

    # Support running both as a package module and as a plain script.
    try:
        from .dispatcher import build_command
    except ImportError:
        import sys as _sys
        import os as _os
        _sys.path.append(_os.path.dirname(_os.path.abspath(__file__)))
        from dispatcher import build_command
    cmd = build_command(train_type, params, gpu_num, use_torchrun)
    if cmd is None:
        return None

    # Log files are read back as UTF-8 elsewhere in this file, so write them
    # as UTF-8 regardless of the platform's default encoding.
    with open(log_file, 'w', encoding='utf-8') as f:
        f.write(f"开始训练 {train_type} 进程\n")
        f.write(f"命令: {' '.join(cmd)}\n\n")

    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors='replace',  # undecodable bytes must not kill the reader thread
        cwd=script_dir
    )

    info = {
        'process': process,
        'train_type': train_type,
        'log_file': log_file,
        'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
        'start_timestamp': time.time(),  # used by the progress estimate
        'running': True,
        'error': False,
        'train_monitor': params.get('train_monitor', 'none'),
        'swanlab_url': None,
        'next_line_is_swanlab_url': False,
        'client_id': client_id
    }
    training_processes[process_id] = info

    def read_output():
        # Work on the captured ``info`` dict rather than looking the entry up
        # again: the entry may be deleted from ``training_processes`` while
        # this thread is still draining output.
        try:
            with open(log_file, 'a', encoding='utf-8') as log:
                for output in process.stdout:
                    output_stripped = output.strip()
                    if info['next_line_is_swanlab_url']:
                        # The line after the marker carries the run URL.
                        info['swanlab_url'] = output_stripped
                        info['next_line_is_swanlab_url'] = False
                    elif 'swanlab: 🚀 View run at' in output_stripped:
                        info['next_line_is_swanlab_url'] = True
                    log.write(output)
                    log.flush()  # keep the file current for log viewers
            # stdout reached EOF; wait so that returncode is populated.
            process.wait()
            if process.returncode != 0:
                info['error'] = True
        finally:
            info['running'] = False

    threading.Thread(target=read_output, daemon=True).start()

    return process_id
../dataset/pretrain_hq.jsonl") +sys.stdout.flush() +time.sleep(1) + +print("2024-11-21 14:30:02 - Model initialized with 108M parameters") +sys.stdout.flush() +time.sleep(2) + +# 测试单epoch但多step的情况,使用新的log格式 +print("2024-11-21 14:30:03 - Epoch:[1/1] Starting training") +sys.stdout.flush() +time.sleep(1) + +total_steps = 20 +for step in range(1, total_steps + 1): + # 模拟step进度,使用新的格式 + if step % 5 == 0 or step == total_steps: + print(f"2024-11-21 14:30:{4 + step} - Epoch:[1/1]({step}/{total_steps}) Processing") + sys.stdout.flush() + + # 模拟训练过程,使用新的格式 + loss = 4.5 - step * 0.1 + lr = 1e-4 * (0.95 ** step) + if step % 3 == 0: + print(f"2024-11-21 14:30:{4 + step} - Epoch:[1/1]({step}/{total_steps}) loss:{loss:.6f} lr:{lr:.2e} epoch_Time:{step * 5.5:.1f}min:") + sys.stdout.flush() + + time.sleep(0.5) + +print("2024-11-21 14:30:25 - Training completed successfully") +sys.stdout.flush() + ''' + ] + + # 启动进程 + process = subprocess.Popen( + test_command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + bufsize=1 + ) + + # 保存进程信息 + log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../logfile') + log_dir = os.path.abspath(log_dir) + os.makedirs(log_dir, exist_ok=True) + + training_processes[process_id] = { + 'process': process, + 'train_type': 'pretrain', + 'log_file': os.path.join(log_dir, f'{process_id}.log'), + 'start_time': time.strftime('%Y-%m-%d %H:%M:%S'), + 'start_timestamp': time.time(), + 'running': True, + 'error': False, + 'train_monitor': 'none', + 'swanlab_url': None + } + + # 启动线程读取输出并写入日志文件 + def read_output(): + try: + log_file = training_processes[process_id]['log_file'] + with open(log_file, 'w') as f: + for line in iter(process.stdout.readline, ''): + if line: + f.write(line) + f.flush() + process.wait() + training_processes[process_id]['running'] = False + if process.returncode != 0: + training_processes[process_id]['error'] = True + except Exception as e: + print(f"读取测试进程输出时出错: {e}") + 
@app.route('/api/browse')
def browse_files():
    """Browse the server file system below the project root.

    Reads the ``path`` query argument (default ``./``).  Relative paths are
    resolved against the project root (the parent of this script's
    directory); any path that resolves outside the root is replaced by the
    root itself.

    Returns:
        A JSON response: a directory listing (hidden entries and entries
        starting with ``__`` are omitted), a single file description, or an
        object with an ``error`` key.
    """
    try:
        path = request.args.get('path', './')

        script_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.abspath(os.path.join(script_dir, '..'))

        if path.startswith('./'):
            # Relative to the project root.
            full_path = os.path.abspath(os.path.join(project_root, path[2:]))
        elif path.startswith('/'):
            # Absolute path; containment is checked below.
            full_path = os.path.abspath(path)
        else:
            # Relative to the project root.
            full_path = os.path.abspath(os.path.join(project_root, path))

        # Containment must be decided on path components.  A plain string
        # prefix test would also accept a sibling such as "<root>_other".
        try:
            inside_root = os.path.commonpath([project_root, full_path]) == project_root
        except ValueError:
            # Raised e.g. for paths on different drives.
            inside_root = False
        if not inside_root:
            full_path = project_root

        if not os.path.exists(full_path):
            return jsonify({'error': '路径不存在', 'path': path})

        if os.path.isdir(full_path):
            items = []
            try:
                for item in sorted(os.listdir(full_path)):
                    # Skip hidden files and dunder entries such as __pycache__.
                    if item.startswith(('.', '__')):
                        continue

                    item_path = os.path.join(full_path, item)
                    try:
                        stat = os.stat(item_path)
                        items.append({
                            'name': item,
                            'path': item_path,  # absolute path
                            'relative_path': os.path.relpath(item_path, project_root),  # for display
                            'type': 'directory' if os.path.isdir(item_path) else 'file',
                            'size': stat.st_size if os.path.isfile(item_path) else 0,
                            'modified': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime))
                        })
                    except OSError:
                        # Entry vanished or is not accessible; leave it out.
                        continue

                return jsonify({
                    'current_path': full_path,  # absolute path
                    'relative_path': os.path.relpath(full_path, project_root),  # for display
                    'absolute_path': full_path,
                    'items': items,
                    # The root has no parent the client may navigate to.
                    'parent': os.path.dirname(full_path) if full_path != project_root else None
                })
            except OSError as e:
                return jsonify({'error': f'无法访问目录: {e}', 'path': path})

        # A regular file: return its metadata.
        stat = os.stat(full_path)
        return jsonify({
            'name': os.path.basename(full_path),
            'path': full_path,  # absolute path
            'relative_path': os.path.relpath(full_path, project_root),  # for display
            'type': 'file',
            'size': stat.st_size,
            'modified': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(stat.st_mtime))
        })

    except Exception as e:
        # Endpoint boundary: report the failure as JSON instead of a 500 page.
        return jsonify({'error': f'浏览文件时出错: {e}'})
jsonify({'paths': valid_paths}) + + except Exception as e: + return jsonify({'error': f'获取快捷路径时出错: {str(e)}'}) + +@app.route('/logs/') +def logs(process_id): + # 直接从本地logfile目录读取日志文件 + # 获取脚本所在目录的绝对路径 + script_dir = os.path.dirname(os.path.abspath(__file__)) + # 构建logfile目录的绝对路径 + log_dir = os.path.join(script_dir, '../logfile') + log_dir = os.path.abspath(log_dir) + + # 查找匹配的日志文件 + log_file = None + if os.path.exists(log_dir): + for filename in os.listdir(log_dir): + if filename.endswith(f'{process_id}.log'): + log_file = os.path.join(log_dir, filename) + break + + if not log_file or not os.path.exists(log_file): + return '日志文件不存在或已被删除' + + try: + # 使用高效且健壮的方法读取文件的最后200行 + def read_last_n_lines(file_path, n=200): + # 使用二进制模式读取文件,避免编码问题 + with open(file_path, 'rb') as f: + # 获取文件大小 + f.seek(0, os.SEEK_END) + file_size = f.tell() + + # 如果文件很小,直接读取整个文件 + if file_size < 1024 * 1024: # 小于1MB的文件直接读取 + f.seek(0) + content = f.read() + return process_content(content) + + # 对于大文件,使用缓冲读取末尾部分 + # 估计需要读取的字节数(假设每行平均100字节) + buffer_size = n * 200 # 为了保险,读取更多字节 + + # 定位到适当的位置 + position = max(0, file_size - buffer_size) + f.seek(position) + + # 读取缓冲区内容 + buffer = f.read(file_size - position) + + # 处理缓冲区内容 + lines = process_content(buffer) + + # 确保我们获取到完整的行 + # 如果缓冲区不是从文件开头开始,第一个行可能不完整 + if position > 0: + # 跳过第一个可能不完整的行 + if len(lines) > 1: + lines = lines[1:] + else: + # 如果只有一行且不在文件开头,可能需要读取更多 + # 这里简单处理,直接读取整个文件(罕见情况) + f.seek(0) + content = f.read() + lines = process_content(content) + + # 返回最后n行 + return lines[-n:] if len(lines) > n else lines + + def process_content(content): + # 尝试多种编码方式解码内容 + encodings = ['utf-8', 'latin-1', 'gbk', 'gb2312'] + for encoding in encodings: + try: + text = content.decode(encoding) + # 使用True参数保留换行符,确保行分隔符正确 + return text.splitlines(True) + except UnicodeDecodeError: + continue + # 如果所有编码都失败,使用错误替换模式 + text = content.decode('utf-8', errors='replace') + return text.splitlines(True) + + # 读取最后200行 + last_200_lines = read_last_n_lines(log_file, 
@app.route('/logfiles')
def get_logfiles():
    """List the ``train_*.log`` files in the log directory, newest first.

    Returns:
        A JSON array of objects with ``filename``, ``modified_time``,
        ``size``, ``process_id`` and ``has_process`` (whether that id is a
        key of ``training_processes``).
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    log_dir = os.path.abspath(os.path.join(script_dir, '../logfile'))

    known_ids = set(training_processes)

    def extract_process_id(stem):
        # Prefer an id that is actually tracked; this mirrors the
        # "filename ends with <process_id>.log" matching used by the other
        # log endpoints.
        tracked = [pid for pid in known_ids if stem.endswith(pid)]
        if tracked:
            return max(tracked, key=len)
        # Process ids are "YYYYmmdd_HHMMSS", i.e. the last TWO underscore
        # separated fields of the name.  Taking only the last field would
        # yield "HHMMSS", which never equals a process id.
        return '_'.join(stem.split('_')[-2:])

    entries = []
    if os.path.isdir(log_dir):
        for filename in os.listdir(log_dir):
            if not (filename.startswith('train_') and filename.endswith('.log')):
                continue
            file_path = os.path.join(log_dir, filename)
            try:
                modified_time = os.path.getmtime(file_path)
                size = os.path.getsize(file_path)
            except OSError:
                # File removed or unreadable between listing and stat.
                continue
            pid = extract_process_id(filename[:-len('.log')])
            entries.append((modified_time, {
                'filename': filename,
                'modified_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(modified_time)),
                'size': size,
                'process_id': pid,
                'has_process': pid in known_ids
            }))

    # Sort on the raw timestamp rather than on the formatted local-time text.
    entries.sort(key=lambda entry: entry[0], reverse=True)
    return jsonify([record for _, record in entries])
in filename or '/' in filename or '\\' in filename: + return jsonify({'error': 'Invalid filename'}), 400 + + # 获取脚本所在目录的绝对路径 + script_dir = os.path.dirname(os.path.abspath(__file__)) + # 构建logfile目录的绝对路径,train_web_ui.py在scripts目录下 + log_dir = os.path.join(script_dir, '../logfile') + log_dir = os.path.abspath(log_dir) + log_file = os.path.join(log_dir, filename) + + try: + # 使用二进制模式读取文件,可以更可靠地保留原始换行符 + with open(log_file, 'rb') as f: + content_bytes = f.read() + + # 尝试多种编码方式解码,确保正确处理换行符 + encodings = ['utf-8', 'latin-1', 'gbk', 'gb2312'] + content = None + + for encoding in encodings: + try: + # 解码文件内容,保留原始换行符 + content = content_bytes.decode(encoding) + break + except UnicodeDecodeError: + continue + + # 如果所有编码都失败,使用errors='replace'参数处理不可解码的字符 + if content is None: + content = content_bytes.decode('utf-8', errors='replace') + + # 确保返回的内容正确保留所有换行符 + return content + except FileNotFoundError: + return jsonify({'error': 'Log file not found'}), 404 + except Exception as e: + return jsonify({'error': str(e)}), 500 + +@app.route('/delete-logfile/', methods=['DELETE']) +def delete_logfile(filename): + # 获取脚本所在目录的绝对路径 + script_dir = os.path.dirname(os.path.abspath(__file__)) + # 构建logfile目录的绝对路径 + log_dir = os.path.join(script_dir, '../logfile') + log_dir = os.path.abspath(log_dir) + + # 安全检查:防止路径遍历攻击 + if '..' 
@app.route('/stop/<process_id>', methods=['POST'])
def stop(process_id):
    """Stop a running training process and mark it as manually stopped.

    The route declares ``<process_id>`` so that Flask can supply the view's
    argument.

    Returns:
        JSON ``{'success': True}`` when the process was signalled (or had
        already exited), ``{'success': False}`` when the id is unknown, the
        process is not running, or it could not be signalled.
    """
    info = training_processes.get(process_id)
    if info is None or not info.get('running'):
        return jsonify({'success': False})

    process = info.get('process')
    if process is not None:
        # Ask politely first, then force-kill if it does not exit in time.
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            try:
                # Reap the killed child so it does not linger as a zombie.
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                pass
    else:
        # Records restored from the persisted JSON carry only a 'pid', not a
        # Popen object.
        pid = info.get('pid')
        if pid is None:
            return jsonify({'success': False})
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # Already gone; fall through and just update the record.
            pass
        except OSError as e:
            print(f"停止进程失败: {str(e)}")
            return jsonify({'success': False})

    info['running'] = False
    info['manually_stopped'] = True
    return jsonify({'success': True})
def find_available_port(start_port=12581, max_attempts=100, host='0.0.0.0'):
    """Find a TCP port that can currently be bound.

    Candidate ports are probed by binding a throw-away socket to
    ``(host, port)``.  Binding tests the operation the server itself will
    perform, whereas a connect probe to localhost only detects listeners
    reachable on the loopback interface.

    Args:
        start_port: First port number to try.
        max_attempts: Maximum number of consecutive ports to try.
        host: Address to bind the probe socket to.

    Returns:
        The first bindable port, or ``None`` if none was found.
    """
    # Clamp to the valid TCP range; out-of-range numbers make the socket
    # call raise OverflowError, and port 0 means "pick any port" to bind().
    first = max(start_port, 1)
    last = min(start_port + max_attempts, 65536)
    for port in range(first, last):
        try:
            # The context manager closes the probe socket on every path.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind((host, port))
        except OSError:
            # In use, privileged, or otherwise not bindable: try the next one.
            continue
        return port
    return None
def handle_exit(signum, frame):
    """Signal handler: persist process records, remove the PID file, exit.

    Args:
        signum: Number of the received signal (unused).
        frame: Current stack frame supplied by the ``signal`` module (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    print("正在保存进程信息... save at 'trainer_web/training_processes.json'...")
    save_processes_info()
    # Best-effort removal of the PID file.  Attempt it directly instead of
    # checking for existence first, and ignore only file-system errors so
    # that unrelated exceptions are not swallowed.
    try:
        os.remove(PID_FILE)
    except OSError:
        pass
    sys.exit(0)
b/trainer_web/training_processes.json @@ -0,0 +1 @@ +{} \ No newline at end of file