mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
Init Web UI
This commit is contained in:
parent
81e869fc3e
commit
9bfb58ad1e
6
.gitignore
vendored
6
.gitignore
vendored
@ -1,4 +1,8 @@
|
||||
__pycache__
|
||||
model/__pycache__
|
||||
out
|
||||
website/
|
||||
docs-minimind/
|
||||
docs-minimind/
|
||||
logfile
|
||||
dataset
|
||||
checkpoints
|
||||
338
scripts/static/css/style.css
Normal file
338
scripts/static/css/style.css
Normal file
@ -0,0 +1,338 @@
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
|
||||
line-height: 1.6;
|
||||
color: #e0e0e0;
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
background-color: #121212;
|
||||
min-height: 100vh;
|
||||
}
|
||||
h1 {
|
||||
color: #ffffff;
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
text-shadow: 0 2px 10px rgba(0, 0, 0, 0.5);
|
||||
font-size: 2.5em;
|
||||
font-weight: 700;
|
||||
}
|
||||
.tabs {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 20px;
|
||||
border-radius: 10px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
.tab {
|
||||
padding: 12px 0;
|
||||
cursor: pointer;
|
||||
background-color: #2a2a2a;
|
||||
border: none;
|
||||
font-size: 16px;
|
||||
font-weight: 500;
|
||||
transition: all 0.3s ease;
|
||||
position: relative;
|
||||
color: #cccccc;
|
||||
width: 30%;
|
||||
text-align: center;
|
||||
}
|
||||
.tab.active {
|
||||
background: linear-gradient(135deg, #4a148c 0%, #8e24aa 100%);
|
||||
color: white;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
.form-container {
|
||||
background: rgba(30, 30, 30, 0.9);
|
||||
padding: 30px;
|
||||
border-radius: 15px;
|
||||
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.3);
|
||||
margin-bottom: 30px;
|
||||
border: 1px solid #333;
|
||||
}
|
||||
|
||||
/* 参数卡片样式 */
|
||||
.parameter-card {
|
||||
background-color: #2d2d2d;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
margin-bottom: 20px;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.3);
|
||||
transition: transform 0.2s ease, box-shadow 0.2s ease;
|
||||
padding-left: 5%;
|
||||
padding-right: 5%;
|
||||
}
|
||||
|
||||
.parameter-card:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
/* 卡片标题样式 */
|
||||
.card-title {
|
||||
color: #e0e0e0;
|
||||
font-size: 1.4rem;
|
||||
font-weight: bold;
|
||||
margin-top: 0;
|
||||
margin-bottom: 20px;
|
||||
padding-bottom: 10px;
|
||||
border-bottom: 1px solid #404040;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
/* 提交按钮容器 */
|
||||
.submit-container {
|
||||
text-align: center;
|
||||
margin-top: 30px;
|
||||
padding-top: 20px;
|
||||
border-top: 1px solid #4d4d4d;
|
||||
}
|
||||
|
||||
/* 参数内容容器 */
|
||||
.parameter-content {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.form-group {
|
||||
width: 40%;
|
||||
float: left;
|
||||
margin-bottom: 15px;
|
||||
margin-right: 10%;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.form-group:nth-child(2n) {
|
||||
margin-right: 0;
|
||||
}
|
||||
|
||||
/* 确保复选框组占满整行 */
|
||||
.form-group.checkbox-group {
|
||||
width: 100%;
|
||||
margin-right: 0;
|
||||
}
|
||||
|
||||
.form-group.pretrain-sft, .form-group.lora {
|
||||
/* 移除100%宽度设置,让这些参数也遵循每行两个的布局 */
|
||||
}
|
||||
|
||||
.parameter-content::after {
|
||||
content: "";
|
||||
display: table;
|
||||
clear: both;
|
||||
}
|
||||
label {
|
||||
display: block;
|
||||
margin-bottom: 5px;
|
||||
color: #ffffff;
|
||||
font-weight: 600;
|
||||
font-size: 0.9rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
input[type="text"], input[type="number"], select {
|
||||
width: 100%;
|
||||
padding: 12px 15px;
|
||||
border: 2px solid #444;
|
||||
border-radius: 8px;
|
||||
font-size: 0.9rem;
|
||||
transition: all 0.3s ease;
|
||||
background-color: #2a2a2a;
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
/* 确保textarea也适应两列布局 */
|
||||
textarea {
|
||||
width: 100%;
|
||||
padding: 12px 15px;
|
||||
border: 2px solid #444;
|
||||
border-radius: 8px;
|
||||
font-size: 0.9rem;
|
||||
transition: all 0.3s ease;
|
||||
background-color: #2a2a2a;
|
||||
color: #ffffff;
|
||||
resize: vertical;
|
||||
}
|
||||
|
||||
input[type="text"]:focus, input[type="number"]:focus, select:focus {
|
||||
outline: none;
|
||||
border-color: #8e24aa;
|
||||
box-shadow: 0 0 0 3px rgba(142, 36, 170, 0.2);
|
||||
background-color: #333;
|
||||
}
|
||||
.checkbox-group {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
.checkbox-group input[type="checkbox"] {
|
||||
width: auto;
|
||||
margin-right: 10px;
|
||||
}
|
||||
button {
|
||||
background: linear-gradient(135deg, #4a148c 0%, #8e24aa 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 12px 25px;
|
||||
border-radius: 8px;
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.3);
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
button:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
.logs-container {
|
||||
background-color: #0d0d0d;
|
||||
color: #e0e0e0;
|
||||
padding: 20px;
|
||||
border-radius: 10px;
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
margin-top: 15px;
|
||||
font-family: 'Courier New', monospace;
|
||||
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.4);
|
||||
transition: all 0.3s ease;
|
||||
border: 1px solid #333;
|
||||
}
|
||||
|
||||
.logs-container:hover {
|
||||
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
.process-item {
|
||||
background: rgba(30, 30, 30, 0.9);
|
||||
padding: 20px;
|
||||
margin-bottom: 15px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.3);
|
||||
transition: all 0.3s ease;
|
||||
border: 1px solid #444;
|
||||
}
|
||||
|
||||
.process-item:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
.process-info {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.process-status {
|
||||
padding: 5px 12px;
|
||||
border-radius: 20px;
|
||||
font-size: 12px;
|
||||
font-weight: bold;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
.status-running {
|
||||
background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
|
||||
color: white;
|
||||
}
|
||||
.status-completed {
|
||||
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
|
||||
color: white;
|
||||
}
|
||||
.status-error {
|
||||
background: linear-gradient(135deg, #ff416c 0%, #ff4b2b 100%);
|
||||
color: white;
|
||||
}
|
||||
.btn-stop {
|
||||
background: linear-gradient(135deg, #ff416c 0%, #ff4b2b 100%);
|
||||
padding: 8px 15px;
|
||||
font-size: 14px;
|
||||
border-radius: 6px;
|
||||
}
|
||||
.btn-stop:hover {
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 4px 10px rgba(255, 65, 108, 0.3);
|
||||
}
|
||||
.btn-logs {
|
||||
background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
|
||||
padding: 8px 15px;
|
||||
font-size: 14px;
|
||||
margin-right: 10px;
|
||||
border-radius: 6px;
|
||||
}
|
||||
.btn-logs:hover {
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 4px 10px rgba(79, 172, 254, 0.3);
|
||||
}
|
||||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
.section-title {
|
||||
color: #ffffff;
|
||||
font-size: 20px;
|
||||
margin-bottom: 25px;
|
||||
text-shadow: 0 2px 5px rgba(0, 0, 0, 0.5);
|
||||
font-weight: 600;
|
||||
padding-bottom: 10px;
|
||||
border-bottom: 3px solid rgba(142, 36, 170, 0.3);
|
||||
}
|
||||
|
||||
/* 添加滚动条样式 */
|
||||
::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-track {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: linear-gradient(135deg, #4a148c 0%, #8e24aa 100%);
|
||||
border-radius: 10px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: linear-gradient(135deg, #6a1b9a 0%, #ab47bc 100%);
|
||||
}
|
||||
|
||||
/* 添加动画效果 */
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(10px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.tab-content {
|
||||
animation: fadeIn 0.5s ease-out;
|
||||
}
|
||||
|
||||
/* 响应式设计 */
|
||||
@media (max-width: 768px) {
|
||||
body {
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.tabs {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.tab {
|
||||
margin-right: 0;
|
||||
margin-bottom: 5px;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.form-container {
|
||||
padding: 20px;
|
||||
}
|
||||
}
|
||||
242
scripts/static/js/script.js
Normal file
242
scripts/static/js/script.js
Normal file
@ -0,0 +1,242 @@
|
||||
// 标签切换功能
|
||||
function openTab(evt, tabName) {
|
||||
var i, tabContent, tabLinks;
|
||||
tabContent = document.getElementsByClassName("tab-content");
|
||||
for (i = 0; i < tabContent.length; i++) {
|
||||
tabContent[i].classList.add("hidden");
|
||||
}
|
||||
tabLinks = document.getElementsByClassName("tab");
|
||||
for (i = 0; i < tabLinks.length; i++) {
|
||||
tabLinks[i].classList.remove("active");
|
||||
}
|
||||
document.getElementById(tabName).classList.remove("hidden");
|
||||
evt.currentTarget.classList.add("active");
|
||||
|
||||
// 如果切换到进程页面,刷新进程列表
|
||||
if (tabName === 'processes') {
|
||||
loadProcesses();
|
||||
} else if (tabName === 'logfiles') {
|
||||
loadLogFiles();
|
||||
}
|
||||
}
|
||||
|
||||
// 根据训练类型显示/隐藏特定参数
|
||||
document.getElementById('train_type').addEventListener('change', function() {
|
||||
const trainType = this.value;
|
||||
const pretrainSftFields = document.querySelectorAll('.pretrain-sft');
|
||||
const loraFields = document.querySelectorAll('.lora');
|
||||
|
||||
pretrainSftFields.forEach(field => {
|
||||
field.style.display = (trainType === 'pretrain' || trainType === 'sft') ? 'block' : 'none';
|
||||
});
|
||||
|
||||
loraFields.forEach(field => {
|
||||
field.style.display = trainType === 'lora' ? 'block' : 'none';
|
||||
});
|
||||
|
||||
// 设置默认值
|
||||
if (trainType === 'pretrain') {
|
||||
document.getElementById('save_dir').value = '../out';
|
||||
document.getElementById('save_weight').value = 'pretrain';
|
||||
document.getElementById('epochs').value = '1';
|
||||
document.getElementById('batch_size').value = '32';
|
||||
document.getElementById('learning_rate').value = '5e-4';
|
||||
document.getElementById('data_path').value = '../dataset/pretrain_hq.jsonl';
|
||||
document.getElementById('from_weight').value = 'none';
|
||||
document.getElementById('log_interval').value = '100';
|
||||
document.getElementById('save_interval').value = '100';
|
||||
} else if (trainType === 'sft') {
|
||||
document.getElementById('save_dir').value = '../out';
|
||||
document.getElementById('save_weight').value = 'full_sft';
|
||||
document.getElementById('epochs').value = '2';
|
||||
document.getElementById('batch_size').value = '16';
|
||||
document.getElementById('learning_rate').value = '5e-7';
|
||||
document.getElementById('data_path').value = '../dataset/sft_mini_512.jsonl';
|
||||
document.getElementById('from_weight').value = 'pretrain';
|
||||
document.getElementById('log_interval').value = '100';
|
||||
document.getElementById('save_interval').value = '100';
|
||||
} else if (trainType === 'lora') {
|
||||
document.getElementById('save_dir').value = '../out/lora';
|
||||
document.getElementById('lora_name').value = 'lora_identity';
|
||||
document.getElementById('epochs').value = '50';
|
||||
document.getElementById('batch_size').value = '32';
|
||||
document.getElementById('learning_rate').value = '1e-4';
|
||||
document.getElementById('data_path').value = '../dataset/lora_identity.jsonl';
|
||||
document.getElementById('from_weight').value = 'full_sft';
|
||||
document.getElementById('log_interval').value = '10';
|
||||
document.getElementById('save_interval').value = '1';
|
||||
}
|
||||
});
|
||||
|
||||
// 初始触发一次change事件以设置默认值
|
||||
document.getElementById('train_type').dispatchEvent(new Event('change'));
|
||||
|
||||
// 加载进程列表
|
||||
function loadProcesses() {
|
||||
fetch('/processes')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
const processList = document.getElementById('process-list');
|
||||
processList.innerHTML = '';
|
||||
|
||||
if (data.length === 0) {
|
||||
processList.innerHTML = '<p>暂无训练进程</p>';
|
||||
return;
|
||||
}
|
||||
|
||||
data.forEach(process => {
|
||||
const processItem = document.createElement('div');
|
||||
processItem.className = 'process-item';
|
||||
|
||||
let statusClass = '';
|
||||
let statusText = '';
|
||||
if (process.running) {
|
||||
statusClass = 'status-running';
|
||||
statusText = '运行中';
|
||||
} else if (process.error) {
|
||||
statusClass = 'status-error';
|
||||
statusText = '出错';
|
||||
} else {
|
||||
statusClass = 'status-completed';
|
||||
statusText = '已完成';
|
||||
}
|
||||
|
||||
processItem.innerHTML = `
|
||||
<div class="process-info">
|
||||
<div>
|
||||
<strong>${process.train_type}</strong> - ${process.start_time}
|
||||
</div>
|
||||
<div>
|
||||
<span class="process-status ${statusClass}">${statusText}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<button class="btn-logs" onclick="showLogs('${process.id}')">查看日志</button>
|
||||
${process.running ? `<button class="btn-stop" onclick="stopProcess('${process.id}')">停止训练</button>` : ''}
|
||||
</div>
|
||||
<div id="logs-${process.id}" class="logs-container hidden"></div>
|
||||
`;
|
||||
|
||||
processList.appendChild(processItem);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 显示日志
|
||||
function showLogs(processId) {
|
||||
const logsContainer = document.getElementById(`logs-${processId}`);
|
||||
logsContainer.classList.toggle('hidden');
|
||||
|
||||
if (!logsContainer.classList.contains('hidden')) {
|
||||
fetch(`/logs/${processId}`)
|
||||
.then(response => response.text())
|
||||
.then(logs => {
|
||||
logsContainer.textContent = logs;
|
||||
logsContainer.scrollTop = logsContainer.scrollHeight;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 停止进程
|
||||
function stopProcess(processId) {
|
||||
if (confirm('确定要停止这个训练进程吗?')) {
|
||||
fetch(`/stop/${processId}`, {
|
||||
method: 'POST'
|
||||
})
|
||||
.then(() => {
|
||||
loadProcesses();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 表单提交处理
|
||||
document.getElementById('train-form').addEventListener('submit', function(e) {
|
||||
e.preventDefault();
|
||||
const formData = new FormData(this);
|
||||
const data = Object.fromEntries(formData.entries());
|
||||
|
||||
fetch('/train', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify(data)
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(result => {
|
||||
if (result.success) {
|
||||
alert('训练已开始!');
|
||||
openTab(event, 'processes');
|
||||
} else {
|
||||
alert('训练启动失败:' + result.error);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// 加载日志文件列表
|
||||
function loadLogFiles() {
|
||||
fetch('/logfiles')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
const logfilesList = document.getElementById('logfiles-list');
|
||||
logfilesList.innerHTML = '';
|
||||
|
||||
if (data.length === 0) {
|
||||
logfilesList.innerHTML = '<p>暂无日志文件</p>';
|
||||
return;
|
||||
}
|
||||
|
||||
// 按日期倒序排序
|
||||
data.sort((a, b) => new Date(b.modified_time) - new Date(a.modified_time));
|
||||
|
||||
data.forEach(logfile => {
|
||||
const fileItem = document.createElement('div');
|
||||
fileItem.className = 'process-item';
|
||||
|
||||
// 从文件名提取训练类型
|
||||
let trainType = '未知';
|
||||
if (logfile.filename.includes('train_pretrain_')) {
|
||||
trainType = 'pretrain';
|
||||
} else if (logfile.filename.includes('train_sft_')) {
|
||||
trainType = 'sft';
|
||||
} else if (logfile.filename.includes('train_lora_')) {
|
||||
trainType = 'lora';
|
||||
}
|
||||
|
||||
fileItem.innerHTML = `
|
||||
<div class="process-info">
|
||||
<div>
|
||||
<strong>${trainType}</strong> - ${logfile.modified_time}
|
||||
</div>
|
||||
<div>
|
||||
<span class="process-status status-completed">已保存</span>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<button class="btn-logs" onclick="viewLogFile('${logfile.filename}', this)">查看日志</button>
|
||||
</div>
|
||||
<div id="log-content-${logfile.filename.replace(/\./g, '-')}" class="logs-container hidden"></div>
|
||||
`;
|
||||
|
||||
logfilesList.appendChild(fileItem);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// 查看日志文件内容
|
||||
function viewLogFile(filename, button) {
|
||||
const safeFilename = filename.replace(/\./g, '-');
|
||||
const logContainer = button.closest('.process-item').querySelector(`#log-content-${safeFilename}`);
|
||||
logContainer.classList.toggle('hidden');
|
||||
|
||||
if (!logContainer.classList.contains('hidden') && logContainer.textContent === '') {
|
||||
logContainer.textContent = '加载日志中...';
|
||||
|
||||
fetch(`/logfile-content/${encodeURIComponent(filename)}`)
|
||||
.then(response => response.text())
|
||||
.then(logs => {
|
||||
logContainer.textContent = logs;
|
||||
logContainer.scrollTop = 0;
|
||||
});
|
||||
}
|
||||
}
|
||||
163
scripts/templates/index.html
Normal file
163
scripts/templates/index.html
Normal file
@ -0,0 +1,163 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>MiniMind 训练 Web UI</title>
|
||||
<link rel="stylesheet" href="/static/css/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<h1>MiniMind 训练 Web UI</h1>
|
||||
|
||||
<div class="tabs">
|
||||
<button class="tab active" onclick="openTab(event, 'train')">开始训练</button>
|
||||
<button class="tab" onclick="openTab(event, 'processes')">训练进程</button>
|
||||
<button class="tab" onclick="openTab(event, 'logfiles')">日志文件</button>
|
||||
</div>
|
||||
|
||||
<div id="train" class="tab-content">
|
||||
<div class="form-container">
|
||||
<h2 class="section-title">选择训练类型并配置参数</h2>
|
||||
<form id="train-form" method="post" action="/train">
|
||||
<!-- 基础训练参数 -->
|
||||
<div class="parameter-card">
|
||||
<h3 class="card-title">基础训练参数</h3>
|
||||
<div class="parameter-content">
|
||||
<div class="form-group">
|
||||
<label for="train_type">训练类型:</label>
|
||||
<select id="train_type" name="train_type" required>
|
||||
<option value="pretrain">预训练 (Pretrain)</option>
|
||||
<option value="sft">监督微调 (Full SFT)</option>
|
||||
<option value="lora">LoRA 微调</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="epochs">训练轮数:</label>
|
||||
<input type="number" id="epochs" name="epochs" min="1" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="batch_size">Batch Size:</label>
|
||||
<input type="number" id="batch_size" name="batch_size" min="1" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="learning_rate">学习率:</label>
|
||||
<input type="text" id="learning_rate" name="learning_rate" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="log_interval">日志打印间隔:</label>
|
||||
<input type="number" id="log_interval" name="log_interval" min="1" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="data_path">数据路径:</label>
|
||||
<input type="text" id="data_path" name="data_path" required>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 模型结构参数 -->
|
||||
<div class="parameter-card">
|
||||
<h3 class="card-title">模型结构参数</h3>
|
||||
<div class="parameter-content">
|
||||
<div class="form-group">
|
||||
<label for="hidden_size">隐藏层维度:</label>
|
||||
<input type="number" id="hidden_size" name="hidden_size" min="128" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="num_hidden_layers">隐藏层数量:</label>
|
||||
<input type="number" id="num_hidden_layers" name="num_hidden_layers" min="1" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="max_seq_len">最大序列长度:</label>
|
||||
<input type="number" id="max_seq_len" name="max_seq_len" min="64" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="use_moe">是否使用MoE架构:</label>
|
||||
<select id="use_moe" name="use_moe">
|
||||
<option value="0">否</option>
|
||||
<option value="1">是</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 硬件配置 -->
|
||||
<div class="parameter-card">
|
||||
<h3 class="card-title">硬件配置</h3>
|
||||
<div class="parameter-content">
|
||||
<div class="form-group">
|
||||
<label for="device">训练设备:</label>
|
||||
<input type="text" id="device" name="device" value="cuda:0" required>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 模型保存与恢复 -->
|
||||
<div class="parameter-card">
|
||||
<h3 class="card-title">模型保存与恢复</h3>
|
||||
<div class="parameter-content">
|
||||
<div class="form-group">
|
||||
<label for="save_dir">模型保存目录:</label>
|
||||
<input type="text" id="save_dir" name="save_dir" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="save_interval">模型保存间隔:</label>
|
||||
<input type="number" id="save_interval" name="save_interval" min="1" required>
|
||||
</div>
|
||||
<div class="form-group pretrain-sft">
|
||||
<label for="save_weight">保存权重前缀名:</label>
|
||||
<input type="text" id="save_weight" name="save_weight">
|
||||
</div>
|
||||
<div class="form-group lora">
|
||||
<label for="lora_name">LoRA权重名称:</label>
|
||||
<input type="text" id="lora_name" name="lora_name">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="from_weight">基于哪个权重训练:</label>
|
||||
<input type="text" id="from_weight" name="from_weight">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<div class="checkbox-group">
|
||||
<input type="checkbox" id="from_resume" name="from_resume" value="1">
|
||||
<label for="from_resume">是否自动检测&续训</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 其他设置 -->
|
||||
<div class="parameter-card">
|
||||
<h3 class="card-title">其他设置</h3>
|
||||
<div class="parameter-content">
|
||||
<div class="form-group">
|
||||
<div class="checkbox-group">
|
||||
<input type="checkbox" id="use_wandb" name="use_wandb" value="1">
|
||||
<label for="use_wandb">是否使用wandb</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="submit-container">
|
||||
<button type="submit">开始训练</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="processes" class="tab-content hidden">
|
||||
<h2 class="section-title">训练进程列表</h2>
|
||||
<div id="process-list">
|
||||
<!-- 进程列表将通过JavaScript动态加载 -->
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="logfiles" class="tab-content hidden">
|
||||
<h2 class="section-title">日志文件列表</h2>
|
||||
<div id="logfiles-list">
|
||||
<!-- 日志文件列表将通过JavaScript动态加载 -->
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="/static/js/script.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
231
scripts/train_web_ui.py
Normal file
231
scripts/train_web_ui.py
Normal file
@ -0,0 +1,231 @@
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import threading
|
||||
import json
|
||||
from flask import Flask, render_template, request, jsonify, redirect, url_for
|
||||
import time
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
app = Flask(__name__, template_folder='templates', static_folder='static')
|
||||
|
||||
# 存储训练进程的信息
|
||||
training_processes = {}
|
||||
|
||||
# 启动训练进程
|
||||
def start_training_process(train_type, params):
|
||||
# 获取脚本所在目录的绝对路径
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 使用详细的时间戳作为进程ID和日志文件名
|
||||
process_id = time.strftime('%Y%m%d_%H%M%S')
|
||||
# 构建logfile目录的绝对路径
|
||||
log_dir = os.path.join(script_dir, '../logfile')
|
||||
log_dir = os.path.abspath(log_dir)
|
||||
log_file = os.path.join(log_dir, f"train_{train_type}_{process_id}.log")
|
||||
|
||||
# 确保日志目录存在
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 构建命令
|
||||
if train_type == 'pretrain':
|
||||
script_path = '../trainer/train_pretrain.py'
|
||||
cmd = [sys.executable, script_path]
|
||||
if 'save_weight' in params:
|
||||
cmd.extend(['--save_weight', params['save_weight']])
|
||||
elif train_type == 'sft':
|
||||
script_path = '../trainer/train_full_sft.py'
|
||||
cmd = [sys.executable, script_path]
|
||||
if 'save_weight' in params:
|
||||
cmd.extend(['--save_weight', params['save_weight']])
|
||||
elif train_type == 'lora':
|
||||
script_path = '../trainer/train_lora.py'
|
||||
cmd = [sys.executable, script_path]
|
||||
if 'lora_name' in params:
|
||||
cmd.extend(['--lora_name', params['lora_name']])
|
||||
else:
|
||||
return None
|
||||
|
||||
# 添加通用参数
|
||||
for key, value in params.items():
|
||||
if key not in ['train_type', 'save_weight', 'lora_name']:
|
||||
# 特殊处理布尔标志参数
|
||||
if key in ['use_wandb', 'from_resume']:
|
||||
if value == '1': # 只有当值为1时才添加这个标志
|
||||
cmd.append(f'--{key}')
|
||||
else:
|
||||
# 确保log_interval和save_interval参数正确传递
|
||||
cmd.extend([f'--{key}', str(value)])
|
||||
|
||||
# 创建日志文件
|
||||
with open(log_file, 'w') as f:
|
||||
f.write(f"开始训练 {train_type} 进程\n")
|
||||
f.write(f"命令: {' '.join(cmd)}\n\n")
|
||||
|
||||
# 启动进程
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
|
||||
# 存储进程信息
|
||||
training_processes[process_id] = {
|
||||
'process': process,
|
||||
'train_type': train_type,
|
||||
'log_file': log_file,
|
||||
'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'running': True,
|
||||
'error': False
|
||||
}
|
||||
|
||||
# 开始读取输出
|
||||
def read_output():
|
||||
try:
|
||||
while True:
|
||||
output = process.stdout.readline()
|
||||
if output == '' and process.poll() is not None:
|
||||
break
|
||||
if output:
|
||||
with open(log_file, 'a') as f:
|
||||
f.write(output)
|
||||
# 检查进程是否成功结束
|
||||
if process.returncode != 0:
|
||||
training_processes[process_id]['error'] = True
|
||||
finally:
|
||||
training_processes[process_id]['running'] = False
|
||||
|
||||
# 启动线程读取输出
|
||||
threading.Thread(target=read_output, daemon=True).start()
|
||||
|
||||
return process_id
|
||||
|
||||
# Flask路由
|
||||
@app.route('/')
|
||||
def index():
|
||||
return render_template('index.html')
|
||||
|
||||
@app.route('/train', methods=['POST'])
|
||||
def train():
|
||||
data = request.json
|
||||
train_type = data.get('train_type')
|
||||
|
||||
# 移除不相关的参数
|
||||
params = data.copy()
|
||||
|
||||
# 处理复选框参数
|
||||
if 'from_resume' not in params:
|
||||
params['from_resume'] = '0'
|
||||
if 'use_wandb' not in params:
|
||||
params['use_wandb'] = '0'
|
||||
|
||||
# 启动训练进程
|
||||
process_id = start_training_process(train_type, params)
|
||||
|
||||
if process_id:
|
||||
return jsonify({'success': True, 'process_id': process_id})
|
||||
else:
|
||||
return jsonify({'success': False, 'error': '无效的训练类型'})
|
||||
|
||||
@app.route('/processes')
|
||||
def processes():
|
||||
result = []
|
||||
for process_id, info in training_processes.items():
|
||||
result.append({
|
||||
'id': process_id,
|
||||
'train_type': info['train_type'],
|
||||
'start_time': info['start_time'],
|
||||
'running': info['running'],
|
||||
'error': info['error']
|
||||
})
|
||||
return jsonify(result)
|
||||
|
||||
@app.route('/logs/<process_id>')
|
||||
def logs(process_id):
|
||||
# 直接从本地logfile目录读取日志文件
|
||||
# 获取脚本所在目录的绝对路径
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 构建logfile目录的绝对路径
|
||||
log_dir = os.path.join(script_dir, '../logfile')
|
||||
log_dir = os.path.abspath(log_dir)
|
||||
|
||||
if os.path.exists(log_dir):
|
||||
for filename in os.listdir(log_dir):
|
||||
if filename.endswith(f'{process_id}.log'):
|
||||
log_file = os.path.join(log_dir, filename)
|
||||
if os.path.exists(log_file):
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
return f'读取日志失败: {str(e)}'
|
||||
return '日志文件不存在或已被删除'
|
||||
|
||||
@app.route('/logfiles')
|
||||
def get_logfiles():
|
||||
# 获取脚本所在目录的绝对路径
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 构建logfile目录的绝对路径
|
||||
log_dir = os.path.join(script_dir, '../logfile')
|
||||
log_dir = os.path.abspath(log_dir)
|
||||
|
||||
logfiles = []
|
||||
if os.path.exists(log_dir):
|
||||
for filename in os.listdir(log_dir):
|
||||
if filename.endswith('.log') and filename.startswith('train_'):
|
||||
file_path = os.path.join(log_dir, filename)
|
||||
try:
|
||||
modified_time = os.path.getmtime(file_path)
|
||||
formatted_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(modified_time))
|
||||
logfiles.append({
|
||||
'filename': filename,
|
||||
'modified_time': formatted_time,
|
||||
'size': os.path.getsize(file_path)
|
||||
})
|
||||
except Exception as e:
|
||||
continue
|
||||
return jsonify(logfiles)
|
||||
|
||||
@app.route('/logfile-content/<filename>')
|
||||
def get_logfile_content(filename):
|
||||
# 获取脚本所在目录的绝对路径
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 构建logfile目录的绝对路径
|
||||
log_dir = os.path.join(script_dir, '../logfile')
|
||||
log_dir = os.path.abspath(log_dir)
|
||||
|
||||
# 安全检查:防止路径遍历攻击
|
||||
if '..' in filename or '/' in filename or '\\' in filename:
|
||||
return '非法的文件名'
|
||||
|
||||
log_file = os.path.join(log_dir, filename)
|
||||
if os.path.exists(log_file) and os.path.isfile(log_file):
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
return f'读取日志失败: {str(e)}'
|
||||
return '日志文件不存在'
|
||||
|
||||
|
||||
@app.route('/stop/<process_id>', methods=['POST'])
|
||||
def stop(process_id):
|
||||
if process_id in training_processes and training_processes[process_id]['running']:
|
||||
process = training_processes[process_id]['process']
|
||||
# 在Windows上使用terminate,在Unix上尝试优雅终止
|
||||
try:
|
||||
process.terminate()
|
||||
# 等待进程结束
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
# 如果超时,强制杀死
|
||||
process.kill()
|
||||
|
||||
training_processes[process_id]['running'] = False
|
||||
return jsonify({'success': True})
|
||||
return jsonify({'success': False})
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host='0.0.0.0', port=5000, debug=True)
|
||||
Loading…
Reference in New Issue
Block a user