diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py
index 3bb15a8..f1e3526 100644
--- a/trainer/train_distill_reason.py
+++ b/trainer/train_distill_reason.py
@@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
-import math
import warnings
import torch
import torch.distributed as dist
@@ -14,23 +13,14 @@ from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
+from trainer.trainer_utils import *
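+# NOTE: trainer_utils is not shown in this diff; it is assumed to provide the shared
+# helpers used below -- Logger, get_lr, setup_seed, init_distributed_mode,
+# is_main_process, init_model, lm_checkpoint and SkipBatchSampler -- with get_lr
+# presumably keeping the removed cosine schedule:
+# lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))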
warnings.filterwarnings('ignore')
-def Logger(content):
- if not ddp or dist.get_rank() == 0:
- print(content)
-
-
-def get_lr(current_step, total_steps, lr):
- return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-def train_epoch(epoch, wandb):
+def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
# Think/answer tag placeholder token ids
start_of_think_ids = tokenizer('<think>').input_ids
end_of_think_ids = tokenizer('</think>').input_ids
@@ -38,28 +28,30 @@ def train_epoch(epoch, wandb):
end_of_answer_ids = tokenizer('</answer>').input_ids
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
- for step, (X, Y, loss_mask) in enumerate(train_loader):
+
+ for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
- lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+ lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
- with ctx:
+ with autocast_ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
+
+ # Up-weight the special-tag positions (specific to reasoning distillation)
sp_ids = torch.isin(Y.view(-1),
torch.tensor(start_of_think_ids + end_of_think_ids
+ start_of_answer_ids + end_of_answer_ids
).to(args.device))
- # Add an extra penalty at the positions matching sp_ids
loss_mask = loss_mask.view(-1)
loss_mask_sum = loss_mask.sum()
- loss_mask[sp_ids] = 10
+ loss_mask[sp_ids] = 10 # 10x weight on the think/answer tag tokens
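+ # note: loss_mask_sum is computed before the 10x up-weighting, so the tag positions
+ # add extra gradient signal rather than renormalizing the mean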
loss_mask = loss_mask.view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask_sum
loss += res.aux_loss
@@ -70,148 +62,112 @@ def train_epoch(epoch, wandb):
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
-
scaler.step(optimizer)
scaler.update()
-
optimizer.zero_grad(set_to_none=True)
- if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+ if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
- Logger(
- 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
- epoch + 1,
- args.epochs,
- step,
- iter_per_epoch,
- loss.item() * args.accumulation_steps,
- optimizer.param_groups[-1]['lr'],
- spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+ current_loss = loss.item() * args.accumulation_steps
+ current_lr = optimizer.param_groups[-1]['lr']
+ eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+ Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+
+ if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
- if (wandb is not None) and (not ddp or dist.get_rank() == 0):
- wandb.log({"loss": loss.item() * args.accumulation_steps,
- "lr": optimizer.param_groups[-1]['lr'],
- "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
-
- if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+ if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
-
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
-
state_dict = {k: v.half() for k, v in state_dict.items()} # save in half precision
torch.save(state_dict, ckp)
+ lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
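+ # lm_checkpoint (assumed from trainer_utils) also dumps a resumable training state
+ # (model/optimizer/scaler/epoch/step/wandb_id) under ../checkpoints; the same helper
+ # called without a model reads that state back (see the from_resume path below)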
model.train()
-def init_model(lm_config):
- tokenizer = AutoTokenizer.from_pretrained('../model')
- model = MiniMindForCausalLM(lm_config)
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
- state_dict = torch.load(ckp, map_location=args.device)
- model.load_state_dict(state_dict, strict=False)
- Logger(f'Total LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
- model = model.to(args.device)
- return model, tokenizer
-
-
-def init_distributed_mode():
- if not ddp: return
- global ddp_local_rank, DEVICE
-
- dist.init_process_group(backend="nccl")
- ddp_rank = int(os.environ["RANK"])
- ddp_local_rank = int(os.environ["LOCAL_RANK"])
- ddp_world_size = int(os.environ["WORLD_SIZE"])
- DEVICE = f"cuda:{ddp_local_rank}"
- torch.cuda.set_device(DEVICE)
-
-
if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="MiniMind Distill Reasoning")
- parser.add_argument("--out_dir", type=str, default="../out")
- parser.add_argument("--epochs", type=int, default=1)
- parser.add_argument("--batch_size", type=int, default=8)
- parser.add_argument("--learning_rate", type=float, default=1e-6)
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
- parser.add_argument("--dtype", type=str, default="bfloat16")
- parser.add_argument("--use_wandb", action="store_true")
- parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
- parser.add_argument("--num_workers", type=int, default=1)
- parser.add_argument("--ddp", action="store_true")
- parser.add_argument("--accumulation_steps", type=int, default=1)
- parser.add_argument("--grad_clip", type=float, default=1.0)
- parser.add_argument("--warmup_iters", type=int, default=0)
- parser.add_argument("--log_interval", type=int, default=1)
- parser.add_argument("--save_interval", type=int, default=50)
- parser.add_argument('--local_rank', type=int, default=-1)
- parser.add_argument('--hidden_size', default=512, type=int)
- parser.add_argument('--num_hidden_layers', default=8, type=int)
- parser.add_argument('--max_seq_len', default=1024, type=int)
- parser.add_argument('--use_moe', default=False, type=bool)
- parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl")
-
+ parser = argparse.ArgumentParser(description="MiniMind Reasoning Distillation")
+ parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
+ parser.add_argument('--save_weight', default='reason', type=str, help="保存权重的前缀名")
+ parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
+ parser.add_argument("--batch_size", type=int, default=8, help="batch size")
+ parser.add_argument("--learning_rate", type=float, default=1e-6, help="初始学习率")
+ parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
+ parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
+ parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
+ parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
+ parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
+ parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
+ parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
+ parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度")
+ parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
+ parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl", help="推理蒸馏数据路径")
+ parser.add_argument('--from_weight', default='dpo', type=str, help="基于哪个权重训练,默认dpo")
+ parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
+ parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
+ parser.add_argument("--wandb_project", type=str, default="MiniMind-Reasoning", help="wandb项目名")
args = parser.parse_args()
- lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
- use_moe=args.use_moe)
- args.save_dir = os.path.join(args.out_dir)
+ # ========== 1. Initialize environment and random seed ==========
+ local_rank = init_distributed_mode()
+ if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+ setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+ # ========== 2. Set up directories, model config, check for checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
- os.makedirs(args.out_dir, exist_ok=True)
- tokens_per_iter = args.batch_size * args.max_seq_len
+ lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+ ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
+
+ # ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
-
- args.wandb_run_name = f"MiniMind-Distill-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
- ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
- ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
- ddp_local_rank, DEVICE = 0, "cuda:0"
- base_seed = 1337
- torch.manual_seed(base_seed)
- torch.cuda.manual_seed(base_seed)
-
- if ddp:
- init_distributed_mode()
- args.device = torch.device(DEVICE)
- rank = dist.get_rank()
- torch.manual_seed(base_seed + rank)
- # 同时设置 CUDA 的随机种子
- torch.cuda.manual_seed(base_seed + rank)
-
- if args.use_wandb and (not ddp or ddp_local_rank == 0):
+ dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+ autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+ # ========== 4. Set up wandb ==========
+ wandb = None
+ if args.use_wandb and is_main_process():
import swanlab as wandb
-
- wandb.init(project=args.wandb_project, name=args.wandb_run_name)
- else:
- wandb = None
-
- model, tokenizer = init_model(lm_config)
-
+ wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+ resume = 'must' if wandb_id else None
+ wandb_run_name = f"MiniMind-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
+ wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+ # ========== 5. Define model, data, optimizer ==========
+ model, tokenizer = init_model(lm_config, args.from_weight)
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
- train_sampler = DistributedSampler(train_ds) if ddp else None
- train_loader = DataLoader(
- train_ds,
- batch_size=args.batch_size,
- pin_memory=True,
- drop_last=False,
- shuffle=(train_sampler is None),
- num_workers=args.num_workers,
- sampler=train_sampler
- )
-
- scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+ train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+ scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
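+ # GradScaler is only enabled for float16: bfloat16 shares float32's exponent range,
+ # so loss scaling is unnecessary (hence the change from the old fp16/bf16 check)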
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
- if ddp:
+
+ # ========== 6. Restore state from checkpoint ==========
+ start_epoch, start_step = 0, 0
+ if ckp_data:
+ model.load_state_dict(ckp_data['model'])
+ optimizer.load_state_dict(ckp_data['optimizer'])
+ scaler.load_state_dict(ckp_data['scaler'])
+ start_epoch = ckp_data['epoch']
+ start_step = ckp_data.get('step', 0)
+
+ # ========== 7. Wrap model with DDP ==========
+ if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
- model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
- iter_per_epoch = len(train_loader)
- for epoch in range(args.epochs):
+ model = DistributedDataParallel(model, device_ids=[local_rank])
+
+ # ========== 8. Start training ==========
+ for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
- train_epoch(epoch, wandb)
+ if epoch == start_epoch and start_step > 0: # first epoch with an existing checkpoint
+ batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+ loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+ Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+ train_epoch(epoch, loader, len(loader) + start_step + 1, tokenizer, lm_config, start_step, wandb)
+ else: # default: start from the beginning
+ loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+ train_epoch(epoch, loader, len(loader), tokenizer, lm_config, 0, wandb)
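+ # SkipBatchSampler (assumed from trainer_utils) yields the usual batch order while
+ # skipping the first start_step + 1 batches, so a resumed run continues mid-epoch;
+ # len(loader) then counts only the remaining batches, which is why iters is passed
+ # as len(loader) + start_step + 1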
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index c97860c..f1de5d2 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -3,11 +3,10 @@ import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
import argparse
import time
-import math
import warnings
-
import torch
import torch.nn.functional as F
import torch.distributed as dist
@@ -15,23 +14,14 @@ from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
+from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
-def Logger(content):
- if not ddp or dist.get_rank() == 0:
- print(content)
-
-
-def get_lr(current_step, total_steps, lr):
- return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
+def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
with torch.no_grad():
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
@@ -45,25 +35,23 @@ def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduct
return (temperature ** 2) * kl
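+# Standard soft-label distillation (Hinton et al., 2015): a KL term between teacher and
+# student distributions softened at temperature T, scaled by T^2 so its gradient scale
+# matches the hard CE term; the two are mixed later as alpha * CE + (1 - alpha) * distill
+# (see --alpha below)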
-def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
+def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
start_time = time.time()
-
+
if teacher_model is not None:
teacher_model.eval()
teacher_model.requires_grad_(False)
- for step, (X, Y, loss_mask) in enumerate(train_loader):
+ for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
- lr = get_lr(epoch * iter_per_epoch + step,
- args.epochs * iter_per_epoch,
- args.learning_rate)
+ lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# Forward pass (student model)
- with ctx:
+ with autocast_ctx:
res = model(X)
student_logits = res.logits
@@ -71,11 +59,11 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
if teacher_model is not None:
with torch.no_grad():
teacher_logits = teacher_model(X).logits
- vocab_size_student = student_logits.size(-1) # N
+ vocab_size_student = student_logits.size(-1)
teacher_logits = teacher_logits[..., :vocab_size_student]
# ========== Compute losses ==========
- # 1) Ground-Truth CE Loss (optional)
+ # 1) Ground-Truth CE Loss
loss_mask_flat = loss_mask.view(-1)
ce_loss = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
@@ -87,10 +75,9 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
if lm_config_student.use_moe:
ce_loss += res.aux_loss
- # 2) Distillation Loss (optional)
+ # 2) Distillation Loss
if teacher_model is not None:
- # Distill only at valid token positions
- distill_loss = distillation_loss_fn(
+ distill_loss = distillation_loss(
student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
temperature=temperature
@@ -110,157 +97,126 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
scaler.update()
optimizer.zero_grad(set_to_none=True)
- if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+ if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
- Logger(
- 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
- epoch,
- args.epochs - 1,
- step,
- iter_per_epoch,
- loss.item(),
- optimizer.param_groups[-1]['lr'],
- spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
- )
- )
-
- if (wandb is not None) and (not ddp or dist.get_rank() == 0):
+ current_loss = loss.item() * args.accumulation_steps
+ current_lr = optimizer.param_groups[-1]['lr']
+ eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+ Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} ce:{ce_loss.item():.4f} distill:{distill_loss.item():.4f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+
+ if wandb:
wandb.log({
- "loss": loss.item(),
+ "loss": current_loss,
"ce_loss": ce_loss.item(),
"distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
- "lr": optimizer.param_groups[-1]['lr'],
- "last-time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
+ "lr": current_lr,
+ "epoch_Time": eta_min
})
- if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+ if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
- moe_path = '_moe' if lm_config_student.use_moe else ''
- ckp = f'{args.save_dir}/full_dist_{lm_config_student.hidden_size}{moe_path}.pth'
+ moe_suffix = '_moe' if lm_config_student.use_moe else ''
+ ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
state_dict = {k: v.half() for k, v in state_dict.items()} # save in half precision
torch.save(state_dict, ckp)
+ lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
-def init_student_model(lm_config):
- tokenizer = AutoTokenizer.from_pretrained('../model/')
- model = MiniMindForCausalLM(lm_config)
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
- state_dict = torch.load(ckp, map_location=args.device)
- model.load_state_dict(state_dict, strict=False)
- Logger(f'Student model (LLM) total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
- model = model.to(args.device)
-
- return model, tokenizer
-
-
-def init_teacher_model(lm_config):
- model = MiniMindForCausalLM(lm_config)
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
- state_dict = torch.load(ckp, map_location=args.device)
- model.load_state_dict(state_dict, strict=False)
- Logger(f'Teacher model (LLM) total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
- model = model.to(args.device)
- return model
-
-
-def init_distributed_mode():
- if not ddp: return
- global ddp_local_rank, DEVICE
-
- dist.init_process_group(backend="nccl")
- ddp_rank = int(os.environ["RANK"])
- ddp_local_rank = int(os.environ["LOCAL_RANK"])
- ddp_world_size = int(os.environ["WORLD_SIZE"])
- DEVICE = f"cuda:{ddp_local_rank}"
- torch.cuda.set_device(DEVICE)
-
-
if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="MiniMind Full SFT")
- parser.add_argument("--out_dir", type=str, default="../out")
- parser.add_argument("--epochs", type=int, default=6)
- parser.add_argument("--batch_size", type=int, default=32)
- parser.add_argument("--learning_rate", type=float, default=5e-6)
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
- parser.add_argument("--dtype", type=str, default="bfloat16")
- parser.add_argument("--use_wandb", action="store_true")
- parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
- parser.add_argument("--num_workers", type=int, default=1)
- parser.add_argument("--ddp", action="store_true")
- parser.add_argument("--accumulation_steps", type=int, default=1)
- parser.add_argument("--grad_clip", type=float, default=1.0)
- parser.add_argument("--warmup_iters", type=int, default=0)
- parser.add_argument("--log_interval", type=int, default=100)
- parser.add_argument("--save_interval", type=int, default=100)
- parser.add_argument("--max_seq_len", type=int, default=512)
- parser.add_argument('--local_rank', type=int, default=-1)
- parser.add_argument("--data_path", type=str, default="../dataset/sft_xxx.jsonl")
-
+ parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")
+ parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
+ parser.add_argument('--save_weight', default='full_dist', type=str, help="保存权重的前缀名")
+ parser.add_argument("--epochs", type=int, default=6, help="训练轮数")
+ parser.add_argument("--batch_size", type=int, default=32, help="batch size")
+ parser.add_argument("--learning_rate", type=float, default=5e-6, help="初始学习率")
+ parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
+ parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
+ parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
+ parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
+ parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
+ parser.add_argument("--max_seq_len", type=int, default=512, help="训练的最大截断长度")
+ parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
+ parser.add_argument('--student_hidden_size', default=512, type=int, help="学生模型隐藏层维度")
+ parser.add_argument('--student_num_layers', default=8, type=int, help="学生模型隐藏层数量")
+ parser.add_argument('--teacher_hidden_size', default=768, type=int, help="教师模型隐藏层维度")
+ parser.add_argument('--teacher_num_layers', default=16, type=int, help="教师模型隐藏层数量")
+ parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
+ parser.add_argument('--from_student_weight', default='full_sft', type=str, help="学生模型基于哪个权重")
+ parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="教师模型基于哪个权重")
+ parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
+ parser.add_argument('--alpha', default=0.5, type=float, help="CE损失权重,总损失=alpha*CE+(1-alpha)*KL")
+ parser.add_argument('--temperature', default=2.0, type=float, help="蒸馏温度")
+ parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
+ parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb项目名")
args = parser.parse_args()
- # Define the student and teacher models
- lm_config_student = MiniMindConfig(hidden_size=512, num_hidden_layers=8)
- lm_config_teacher = MiniMindConfig(hidden_size=768, num_hidden_layers=16)
- args.save_dir = os.path.join(args.out_dir)
+
+ # ========== 1. Initialize environment and random seed ==========
+ local_rank = init_distributed_mode()
+ if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+ setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+ # ========== 2. Set up directories, model configs, check for checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
- os.makedirs(args.out_dir, exist_ok=True)
- tokens_per_iter = args.batch_size * args.max_seq_len
+ lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=args.use_moe)
+ lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=args.use_moe)
+ ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
+
+ # ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
-
- args.wandb_run_name = f"MiniMind-Dist-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
- ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
- ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
- ddp_local_rank, DEVICE = 0, "cuda:0"
- base_seed = 1337
- torch.manual_seed(base_seed)
- torch.cuda.manual_seed(base_seed)
-
- if ddp:
- init_distributed_mode()
- args.device = torch.device(DEVICE)
- rank = dist.get_rank()
- torch.manual_seed(base_seed + rank)
- # 同时设置 CUDA 的随机种子
- torch.cuda.manual_seed(base_seed + rank)
-
- if args.use_wandb and (not ddp or ddp_local_rank == 0):
+ dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+ autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+ # ========== 4. Set up wandb ==========
+ wandb = None
+ if args.use_wandb and is_main_process():
import swanlab as wandb
-
- wandb.init(project=args.wandb_project, name=args.wandb_run_name)
- else:
- wandb = None
-
- # Initialize the student and teacher models
- model, tokenizer = init_student_model(lm_config_student)
- teacher_model = init_teacher_model(lm_config_teacher)
-
+ wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+ resume = 'must' if wandb_id else None
+ wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
+ wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+ # ========== 5. Define the student and teacher models ==========
+ model, tokenizer = init_model(lm_config_student, args.from_student_weight)
+ Logger(f'Student model total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
+ teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight)
+ teacher_model.eval()
+ teacher_model.requires_grad_(False)
+ Logger(f'Teacher model total parameters: {sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
- train_sampler = DistributedSampler(train_ds) if ddp else None
- train_loader = DataLoader(
- train_ds,
- batch_size=args.batch_size,
- pin_memory=True,
- drop_last=False,
- shuffle=(train_sampler is None),
- num_workers=args.num_workers,
- sampler=train_sampler
- )
-
- scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+ train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+ scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
- if ddp:
+
+ # ========== 6. Restore state from checkpoint ==========
+ start_epoch, start_step = 0, 0
+ if ckp_data:
+ model.load_state_dict(ckp_data['model'])
+ optimizer.load_state_dict(ckp_data['optimizer'])
+ scaler.load_state_dict(ckp_data['scaler'])
+ start_epoch = ckp_data['epoch']
+ start_step = ckp_data.get('step', 0)
+
+ # ========== 7. Wrap model with DDP ==========
+ if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
- model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
- iter_per_epoch = len(train_loader)
- for epoch in range(args.epochs):
+ model = DistributedDataParallel(model, device_ids=[local_rank])
+
+ # ========== 8. Start training ==========
+ for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
- train_epoch(epoch, wandb)
+ if epoch == start_epoch and start_step > 0: # first epoch with an existing checkpoint
+ batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+ loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+ Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+ train_epoch(epoch, loader, len(loader) + start_step + 1, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
+ else: # default: start from the beginning
+ loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+ train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index 8c7bb45..b4e7b37 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
-import math
import warnings
import torch
import torch.nn.functional as F
@@ -15,55 +14,47 @@ from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import DPODataset
+from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
-def Logger(content):
- if not ddp or dist.get_rank() == 0:
- print(content)
-
-
-def get_lr(current_step, total_steps, lr):
- return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-def logits_to_probs(logits, labels):
+def logits_to_log_probs(logits, labels):
# logits shape: (batch_size, seq_len, vocab_size)
# labels shape: (batch_size, seq_len)
- # probs shape: (batch_size, seq_len)
+ # log_probs shape: (batch_size, seq_len)
log_probs = F.log_softmax(logits, dim=2)
- probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
- return probs
+ log_probs_per_token = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
+ return log_probs_per_token
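+# i.e. log_probs_per_token[b, t] = log_softmax(logits)[b, t, labels[b, t]]: the per-token
+# log-likelihood of each target token under the model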
-def dpo_loss(ref_probs, probs, mask, beta):
- # ref_probs and probs both have shape: (batch_size, seq_len)
+def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
+ # ref_log_probs and policy_log_probs both have shape: (batch_size, seq_len)
# https://github.com/jingyaogong/minimind/issues/298
- seq_lengths = mask.sum(dim=1, keepdim=True) # (batch_size, 1)
- ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
- probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()
+ seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8) # guard against zero-length masks causing divide-by-zero NaNs
+ ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
+ policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
# Split into chosen and rejected halves
- batch_size = ref_probs.shape[0]
- chosen_ref_probs = ref_probs[:batch_size // 2]
- reject_ref_probs = ref_probs[batch_size // 2:]
- chosen_probs = probs[:batch_size // 2]
- reject_probs = probs[batch_size // 2:]
+ batch_size = ref_log_probs.shape[0]
+ chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
+ reject_ref_log_probs = ref_log_probs[batch_size // 2:]
+ chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
+ reject_policy_log_probs = policy_log_probs[batch_size // 2:]
- pi_logratios = chosen_probs - reject_probs
- ref_logratios = chosen_ref_probs - reject_ref_probs
+ pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
+ ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
logits = pi_logratios - ref_logratios
loss = -F.logsigmoid(beta * logits)
return loss.mean()
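+# DPO objective (Rafailov et al., 2023):
+#   loss = -log sigmoid(beta * ((log pi(y_w) - log pi(y_l)) - (log pi_ref(y_w) - log pi_ref(y_l))))
+# computed here on length-normalized sequence log-probs (see the issue linked above)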
-def train_epoch(epoch, wandb):
+def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
start_time = time.time()
- for step, batch in enumerate(train_loader):
+
+ for step, batch in enumerate(loader, start=start_step + 1):
x_chosen = batch['x_chosen'].to(args.device)
x_rejected = batch['x_rejected'].to(args.device)
y_chosen = batch['y_chosen'].to(args.device)
@@ -74,21 +65,21 @@ def train_epoch(epoch, wandb):
y = torch.cat([y_chosen, y_rejected], dim=0)
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
- lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+ lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
- with ctx:
+ with autocast_ctx:
with torch.no_grad():
ref_outputs = ref_model(x)
ref_logits = ref_outputs.logits
- ref_probs = logits_to_probs(ref_logits, y)
- ref_probs = ref_probs * mask
+ ref_log_probs = logits_to_log_probs(ref_logits, y)
+
outputs = model(x)
logits = outputs.logits
- probs = logits_to_probs(logits, y)
- probs = probs * mask
- loss = dpo_loss(ref_probs, probs, mask, beta=0.1)
+ policy_log_probs = logits_to_log_probs(logits, y)
+
+ loss = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
@@ -100,150 +91,116 @@ def train_epoch(epoch, wandb):
scaler.update()
optimizer.zero_grad(set_to_none=True)
- if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+ if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
- Logger(
- 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
- epoch + 1,
- args.epochs,
- step,
- iter_per_epoch,
- loss.item() * args.accumulation_steps,
- optimizer.param_groups[-1]['lr'],
- spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+ current_loss = loss.item() * args.accumulation_steps
+ current_lr = optimizer.param_groups[-1]['lr']
+ eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+ Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+
+ if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
- if (wandb is not None) and (not ddp or dist.get_rank() == 0):
- wandb.log({"loss": loss.item() * args.accumulation_steps,
- "lr": optimizer.param_groups[-1]['lr'],
- "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
-
- if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+ if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
-
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
state_dict = {k: v.half() for k, v in state_dict.items()} # save in half precision
torch.save(state_dict, ckp)
+ lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
-def init_model(lm_config):
- tokenizer = AutoTokenizer.from_pretrained('../model/')
- model = MiniMindForCausalLM(lm_config)
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
- state_dict = torch.load(ckp, map_location=args.device)
- model.load_state_dict(state_dict, strict=False)
- # 初始化参考模型
- ref_model = MiniMindForCausalLM(lm_config)
- ref_model.load_state_dict(state_dict, strict=False)
- ref_model.eval()
- ref_model.requires_grad_(False)
-
- Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
- model = model.to(args.device)
- ref_model = ref_model.to(args.device)
-
- return model, ref_model, tokenizer
-
-
-def init_distributed_mode():
- if not ddp: return
- global ddp_local_rank, DEVICE
-
- dist.init_process_group(backend="nccl")
- ddp_rank = int(os.environ["RANK"])
- ddp_local_rank = int(os.environ["LOCAL_RANK"])
- ddp_world_size = int(os.environ["WORLD_SIZE"])
- DEVICE = f"cuda:{ddp_local_rank}"
- torch.cuda.set_device(DEVICE)
-
-
if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="MiniMind RLHF")
- parser.add_argument("--out_dir", type=str, default="../out")
- parser.add_argument("--epochs", type=int, default=2)
- parser.add_argument("--batch_size", type=int, default=4)
- # SFT-stage lr goes 5e-6 -> 5e-7 at length 512; for offline preference alignment on positive/negative sample probabilities, lr <= 1e-8 at length 3000 is recommended, otherwise the model easily forgets and degrades
- parser.add_argument("--learning_rate", type=float, default=1e-8)
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
- parser.add_argument("--dtype", type=str, default="bfloat16")
- parser.add_argument("--use_wandb", action="store_true")
- parser.add_argument("--wandb_project", type=str, default="MiniMind-RLHF-SFT")
- parser.add_argument("--num_workers", type=int, default=1)
- parser.add_argument("--ddp", action="store_true")
- parser.add_argument("--accumulation_steps", type=int, default=1)
- parser.add_argument("--grad_clip", type=float, default=1.0)
- parser.add_argument("--warmup_iters", type=int, default=0)
- parser.add_argument("--log_interval", type=int, default=100)
- parser.add_argument("--save_interval", type=int, default=100)
- parser.add_argument('--local_rank', type=int, default=-1)
- parser.add_argument('--hidden_size', default=512, type=int)
- parser.add_argument('--num_hidden_layers', default=8, type=int)
- parser.add_argument('--max_seq_len', default=1024, type=int)
- parser.add_argument('--use_moe', default=False, type=bool)
- parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl")
-
+ parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
+ parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
+ parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
+ parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
+ parser.add_argument("--batch_size", type=int, default=4, help="batch size")
+ parser.add_argument("--learning_rate", type=float, default=5e-8, help="初始学习率(建议<=5e-8避免遗忘)")
+ parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
+ parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
+ parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
+ parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
+ parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
+ parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
+ parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
+ parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度")
+ parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
+ parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
+ parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
+ parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
+ parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
+ parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
+ parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
args = parser.parse_args()
- lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
- use_moe=args.use_moe)
- args.save_dir = os.path.join(args.out_dir)
+ # ========== 1. Initialize environment and random seed ==========
+ local_rank = init_distributed_mode()
+ if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+ setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+ # ========== 2. Set up directories, model config, check for checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
- os.makedirs(args.out_dir, exist_ok=True)
- tokens_per_iter = args.batch_size * args.max_seq_len
+ lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+ ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
+
+ # ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
-
- args.wandb_run_name = f"MiniMind-Full-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
- ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
- ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
- ddp_local_rank, DEVICE = 0, "cuda:0"
- base_seed = 1337
- torch.manual_seed(base_seed)
- torch.cuda.manual_seed(base_seed)
-
- if ddp:
- init_distributed_mode()
- args.device = torch.device(DEVICE)
- rank = dist.get_rank()
- torch.manual_seed(base_seed + rank)
- # 同时设置 CUDA 的随机种子
- torch.cuda.manual_seed(base_seed + rank)
-
- if args.use_wandb and (not ddp or ddp_local_rank == 0):
+ dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+ autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+ # ========== 4. Set up wandb ==========
+ wandb = None
+ if args.use_wandb and is_main_process():
import swanlab as wandb
-
- wandb.init(project=args.wandb_project, name=args.wandb_run_name)
- else:
- wandb = None
-
- model, ref_model, tokenizer = init_model(lm_config)
-
+ wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+ resume = 'must' if wandb_id else None
+ wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
+ wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+ # ========== 5. Define the policy and reference models ==========
+ model, tokenizer = init_model(lm_config, args.from_weight)
+ Logger(f'Policy model total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
+ # Initialize the reference model (frozen)
+ ref_model, _ = init_model(lm_config, args.from_weight)
+ ref_model.eval()
+ ref_model.requires_grad_(False)
+ Logger(f'Reference model total parameters: {sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
+
train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
- train_sampler = DistributedSampler(train_ds) if ddp else None
- train_loader = DataLoader(
- train_ds,
- batch_size=args.batch_size,
- pin_memory=True,
- drop_last=False,
- shuffle=(train_sampler is None),
- num_workers=args.num_workers,
- sampler=train_sampler
- )
-
- scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+ train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+ scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
- if ddp:
+
+ # ========== 6. Restore state from checkpoint ==========
+ start_epoch, start_step = 0, 0
+ if ckp_data:
+ model.load_state_dict(ckp_data['model'])
+ optimizer.load_state_dict(ckp_data['optimizer'])
+ scaler.load_state_dict(ckp_data['scaler'])
+ start_epoch = ckp_data['epoch']
+ start_step = ckp_data.get('step', 0)
+
+ # ========== 7. Wrap model with DDP ==========
+ if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
- model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
- iter_per_epoch = len(train_loader)
- for epoch in range(args.epochs):
+ model = DistributedDataParallel(model, device_ids=[local_rank])
+
+ # ========== 8. Start training ==========
+ for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
- train_epoch(epoch, wandb)
+ if epoch == start_epoch and start_step > 0: # first epoch with an existing checkpoint
+ batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+ loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+ Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+ train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, lm_config, start_step, wandb, args.beta)
+ else: # default: start from the beginning
+ loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+ train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index 041801f..09fa941 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
-import math
import warnings
import torch
import torch.distributed as dist
@@ -14,34 +13,25 @@ from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
+from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
-def Logger(content):
- if not ddp or dist.get_rank() == 0:
- print(content)
-
-
-def get_lr(current_step, total_steps, lr):
- return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-def train_epoch(epoch, wandb):
+def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
- for step, (X, Y, loss_mask) in enumerate(train_loader):
+ for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
- lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+ lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
- with ctx:
+ with autocast_ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
@@ -63,141 +53,109 @@ def train_epoch(epoch, wandb):
optimizer.zero_grad(set_to_none=True)
- if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+ if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
- Logger(
- 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
- epoch + 1,
- args.epochs,
- step,
- iter_per_epoch,
- loss.item() * args.accumulation_steps,
- optimizer.param_groups[-1]['lr'],
- spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+ current_loss = loss.item() * args.accumulation_steps
+ current_lr = optimizer.param_groups[-1]['lr']
+ eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+ Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+
+ if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
- if (wandb is not None) and (not ddp or dist.get_rank() == 0):
- wandb.log({"loss": loss.item() * args.accumulation_steps,
- "lr": optimizer.param_groups[-1]['lr'],
- "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
-
- if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+ if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
state_dict = {k: v.half() for k, v in state_dict.items()} # save in half precision
torch.save(state_dict, ckp)
+ lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
+ epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
model.train()
-def init_model(lm_config):
- tokenizer = AutoTokenizer.from_pretrained('../model')
- model = MiniMindForCausalLM(lm_config)
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/pretrain_{lm_config.hidden_size}{moe_path}.pth'
- state_dict = torch.load(ckp, map_location=args.device)
- model.load_state_dict(state_dict, strict=False)
-
- Logger(f'LLM total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
- model = model.to(args.device)
- return model, tokenizer
-
-
-def init_distributed_mode():
- if not ddp: return
- global ddp_local_rank, DEVICE
-
- dist.init_process_group(backend="nccl")
- ddp_rank = int(os.environ["RANK"])
- ddp_local_rank = int(os.environ["LOCAL_RANK"])
- ddp_world_size = int(os.environ["WORLD_SIZE"])
- DEVICE = f"cuda:{ddp_local_rank}"
- torch.cuda.set_device(DEVICE)
-
-
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
- parser.add_argument("--out_dir", type=str, default="../out")
- parser.add_argument("--epochs", type=int, default=2)
- parser.add_argument("--batch_size", type=int, default=16)
- parser.add_argument("--learning_rate", type=float, default=5e-7)
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
- parser.add_argument("--dtype", type=str, default="bfloat16")
- parser.add_argument("--use_wandb", action="store_true")
- parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
- parser.add_argument("--num_workers", type=int, default=1)
- parser.add_argument("--ddp", action="store_true")
- parser.add_argument("--accumulation_steps", type=int, default=1)
- parser.add_argument("--grad_clip", type=float, default=1.0)
- parser.add_argument("--warmup_iters", type=int, default=0)
- parser.add_argument("--log_interval", type=int, default=100)
- parser.add_argument("--save_interval", type=int, default=100)
- parser.add_argument('--local_rank', type=int, default=-1)
- parser.add_argument('--hidden_size', default=512, type=int)
- parser.add_argument('--num_hidden_layers', default=8, type=int)
- parser.add_argument('--max_seq_len', default=512, type=int)
- parser.add_argument('--use_moe', default=False, type=bool)
- parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl")
-
+ parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
+ parser.add_argument('--save_weight', default='full_sft', type=str, help="保存权重的前缀名")
+ parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
+ parser.add_argument("--batch_size", type=int, default=16, help="batch size")
+ parser.add_argument("--learning_rate", type=float, default=5e-7, help="初始学习率")
+ parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
+ parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
+ parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
+ parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
+ parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
+ parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
+ parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
+ parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
+ parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
+ parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
+ parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练,为none则不基于任何权重训练")
+ parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
+ parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
+ parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名")
args = parser.parse_args()
- lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
- use_moe=args.use_moe)
- args.save_dir = os.path.join(args.out_dir)
+ # ========== 1. Initialize environment and random seed ==========
+ local_rank = init_distributed_mode()
+ if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+ setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+ # ========== 2. Set up directories, model config, check for checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
- os.makedirs(args.out_dir, exist_ok=True)
- tokens_per_iter = args.batch_size * args.max_seq_len
+ lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+ ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
+
+ # ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
-
- args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
- ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
- ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
- ddp_local_rank, DEVICE = 0, "cuda:0"
- base_seed = 1337
- torch.manual_seed(base_seed)
- torch.cuda.manual_seed(base_seed)
-
- if ddp:
- init_distributed_mode()
- args.device = torch.device(DEVICE)
- rank = dist.get_rank()
- torch.manual_seed(base_seed + rank)
- # 同时设置 CUDA 的随机种子
- torch.cuda.manual_seed(base_seed + rank)
-
- if args.use_wandb and (not ddp or ddp_local_rank == 0):
+ dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+ autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+ # ========== 4. Set up wandb ==========
+ wandb = None
+ if args.use_wandb and is_main_process():
import swanlab as wandb
-
- wandb.init(project=args.wandb_project, name=args.wandb_run_name)
- else:
- wandb = None
-
- model, tokenizer = init_model(lm_config)
-
+ wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+ resume = 'must' if wandb_id else None
+ wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+ wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+ # ========== 5. Define model, data, optimizer ==========
+ model, tokenizer = init_model(lm_config, args.from_weight)
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
- train_sampler = DistributedSampler(train_ds) if ddp else None
- train_loader = DataLoader(
- train_ds,
- batch_size=args.batch_size,
- pin_memory=True,
- drop_last=False,
- shuffle=(train_sampler is None),
- num_workers=args.num_workers,
- sampler=train_sampler
- )
-
- scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+ train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+ scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
- if ddp:
+
+ # ========== 6. Restore state from checkpoint ==========
+ start_epoch, start_step = 0, 0
+ if ckp_data:
+ model.load_state_dict(ckp_data['model'])
+ optimizer.load_state_dict(ckp_data['optimizer'])
+ scaler.load_state_dict(ckp_data['scaler'])
+ start_epoch = ckp_data['epoch']
+ start_step = ckp_data.get('step', 0)
+
+ # ========== 7. Wrap model with DDP ==========
+ if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
- model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
- iter_per_epoch = len(train_loader)
- for epoch in range(args.epochs):
+ model = DistributedDataParallel(model, device_ids=[local_rank])
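+ # freqs_cos/freqs_sin are deterministic RoPE buffers recomputed identically on every rank,
+ # so DDP is told not to broadcast or sync them.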
+
+ # ========== 8. Start training ==========
+ for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
- train_epoch(epoch, wandb)
+ if epoch == start_epoch and start_step > 0: # first epoch when resuming from a checkpoint
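+ # SkipBatchSampler (trainer_utils) fast-forwards the sampler past the batches consumed before
+ # the checkpoint, so the resumed epoch only yields batches the model has not seen.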
+ batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+ loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+ Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+ train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
+ else: # default: start from scratch
+ loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+ train_epoch(epoch, loader, len(loader), 0, wandb)
diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py
index 704800d..22a4e26 100755
--- a/trainer/train_grpo.py
+++ b/trainer/train_grpo.py
@@ -3,59 +3,49 @@ import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
import argparse
-import time
import re
import gc
+import warnings
import torch
-from contextlib import nullcontext
import torch.distributed as dist
+from transformers import AutoTokenizer
+from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
-from torch.optim.lr_scheduler import CosineAnnealingLR
+from trainer.trainer_utils import *
-
-def Logger(content):
- if not ddp or dist.get_rank() == 0:
- print(content)
+warnings.filterwarnings('ignore')
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
-
def reasoning_model_reward(rewards):
- # 1. 格式奖励(仅针对训练推理模型时使用)
pattern = r"^\n.*?\n\n\n.*?\n$"
pattern2 = r"^\n.*?\n\n\n\n.*?\n$"
-
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
- if match_pattern:
- format_rewards.append(0.5)
- elif match_pattern2:
+ if match_pattern or match_pattern2:
format_rewards.append(0.5)
else:
format_rewards.append(0.0)
rewards += torch.tensor(format_rewards, device=args.device)
- # 2. 标记奖励(防止严格奖励稀疏,仅针对训练推理模型时使用)
def mark_num(text):
reward = 0
- if text.count("<think>") == 1:
- reward += 0.25
- if text.count("</think>") == 1:
- reward += 0.25
- if text.count("<answer>") == 1:
- reward += 0.25
- if text.count("</answer>") == 1:
- reward += 0.25
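+ # Dense tag reward: +0.25 for each of the four tags appearing exactly once, softening the
+ # sparsity of the strict format regex above.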
+ if text.count("<think>") == 1: reward += 0.25
+ if text.count("</think>") == 1: reward += 0.25
+ if text.count("<answer>") == 1: reward += 0.25
+ if text.count("</answer>") == 1: reward += 0.25
return reward
mark_rewards = [mark_num(response) for response in responses]
@@ -63,12 +53,9 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
rewards = torch.zeros(len(responses), device=args.device)
-
- # 3. 格式奖励
if args.reasoning == 1:
- rewards = reasoning_model_reward(rewards) # 训练推理模型时使用
+ rewards = reasoning_model_reward(rewards)
- # 4. 使用reward model计算奖励
with torch.no_grad():
reward_model_scores = []
batch_size = len(prompts)
@@ -105,8 +92,8 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
-def grpo_train_epoch(epoch, wandb):
- for step, batch in enumerate(train_loader):
+def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
+ for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch['prompt'] # list[str], length B
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
@@ -115,7 +102,9 @@ def grpo_train_epoch(epoch, wandb):
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
with torch.no_grad():
- outputs = (model.module if ddp else model).generate(
+ # A DDP-wrapped model exposes generate() via .module
+ model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
+ outputs = model_for_gen.generate(
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R]
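+ # Each prompt is sampled num_generations times; GRPO derives advantages from rewards
+ # normalized within each group of generations, so no learned critic is needed.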
@@ -161,36 +150,33 @@ def grpo_train_epoch(epoch, wandb):
scheduler.step()
optimizer.zero_grad()
- if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+ if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item()
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
current_lr = optimizer.param_groups[0]['lr']
- Logger(
- f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
- f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
- f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
+ Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
+ f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
+ f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
- if wandb and (not ddp or dist.get_rank() == 0):
- log_dict = {
+ if wandb and is_main_process():
+ wandb.log({
"policy_loss": policy_loss_val,
"reward": avg_reward_val,
"avg_response_len": avg_len_val,
"advantages_mean": advantages.mean().item(),
"learning_rate": current_lr
- }
- wandb.log(log_dict)
+ })
- if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+ if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
- moe_path = '_moe' if lm_config.use_moe else ''
- suffix = 'grpo'
- ckp = f'{args.save_dir}/{suffix}_{lm_config.hidden_size}{moe_path}.pth'
-
- state_dict = model.module.state_dict() if isinstance(model,
- torch.nn.parallel.DistributedDataParallel) else model.state_dict()
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+ state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
+ lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
+ epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
@@ -199,119 +185,114 @@ def grpo_train_epoch(epoch, wandb):
gc.collect()
-def init_model(lm_config):
- tokenizer = AutoTokenizer.from_pretrained('../model/')
- model = MiniMindForCausalLM(lm_config)
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
- if args.reasoning == 1:
- ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
- state_dict = torch.load(ckp, map_location=args.device)
- model.load_state_dict(state_dict, strict=False)
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="MiniMind GRPO (Group Relative Policy Optimization)")
+ parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
+ parser.add_argument('--save_weight', default='grpo', type=str, help="保存权重的前缀名")
+ parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
+ parser.add_argument("--batch_size", type=int, default=2, help="batch size")
+ parser.add_argument("--learning_rate", type=float, default=8e-8, help="初始学习率")
+ parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
+ parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
+ parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
+ parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
+ parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
+ parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
+ parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
+ parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
+ parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
+ parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
+ parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
+ parser.add_argument("--num_generations", type=int, default=8, help="每个prompt生成的样本数")
+ parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
+ parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
+ parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
+ parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
+ parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
+ parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名")
+ args = parser.parse_args()
+ # ========== 1. Initialize environment and random seeds ==========
+ local_rank = init_distributed_mode()
+ if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+ setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+ # ========== 2. Set up directories, model config, and check for checkpoints ==========
+ os.makedirs(args.save_dir, exist_ok=True)
+ lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
+ max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe)
+ ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
+
+ # ========== 3. Set up mixed precision ==========
+ device_type = "cuda" if "cuda" in args.device else "cpu"
+ dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+ autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+ # ========== 4. Configure wandb ==========
+ wandb = None
+ if args.use_wandb and is_main_process():
+ import swanlab as wandb
+ wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+ resume = 'must' if wandb_id else None
+ wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
+ wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+ # ========== 5. Initialize models and data ==========
+ tokenizer = AutoTokenizer.from_pretrained('../model/')
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ base_weight = "reason" if args.reasoning == 1 else "full_sft"
+ ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+ state_dict = torch.load(ckp, map_location=args.device)
+ # Policy model
+ model = MiniMindForCausalLM(lm_config)
+ model.load_state_dict(state_dict, strict=False)
+ model = model.to(args.device)
+ Logger(f'Policy model total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
+ # Reference model
ref_model = MiniMindForCausalLM(lm_config)
ref_model.load_state_dict(state_dict, strict=False)
ref_model.eval().requires_grad_(False)
-
- Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
- model = model.to(args.device)
ref_model = ref_model.to(args.device)
-
- reward_name = "../../internlm2-1_8b-reward"
+ # Reward model
reward_model = AutoModel.from_pretrained(
- reward_name,
- device_map="cuda",
- torch_dtype=torch.float16,
- trust_remote_code=True,
+ args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True
).to(args.device).eval().requires_grad_(False)
- reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True)
-
- return model, ref_model, tokenizer, reward_model, reward_tokenizer
-
-
-def init_distributed_mode():
- if not ddp: return
- global ddp_local_rank, DEVICE
- dist.init_process_group(backend="nccl")
- ddp_local_rank = int(os.environ["LOCAL_RANK"])
- DEVICE = f"cuda:{ddp_local_rank}"
- torch.cuda.set_device(DEVICE)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--out_dir", type=str, default="../out")
- parser.add_argument("--epochs", type=int, default=1)
- parser.add_argument("--batch_size", type=int, default=2)
- parser.add_argument("--learning_rate", type=float, default=8e-8)
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
- parser.add_argument("--dtype", type=str, default="bfloat16")
- parser.add_argument("--use_wandb", action="store_true")
- parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO")
- parser.add_argument("--num_workers", type=int, default=1)
- parser.add_argument("--ddp", action="store_true")
- parser.add_argument("--accumulation_steps", type=int, default=1)
- parser.add_argument("--grad_clip", type=float, default=1.0)
- parser.add_argument("--log_interval", type=int, default=1)
- parser.add_argument("--save_interval", type=int, default=10)
- parser.add_argument('--hidden_size', default=512, type=int)
- parser.add_argument('--num_hidden_layers', default=8, type=int)
- parser.add_argument('--use_moe', default=False, type=bool)
- parser.add_argument('--max_seq_len', default=66, type=int)
- parser.add_argument("--max_gen_len", type=int, default=1536)
- parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
- parser.add_argument("--num_generations", type=int, default=8)
- parser.add_argument("--beta", type=float, default=0.02)
- parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
- args = parser.parse_args()
-
- lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
- max_seq_len=args.max_seq_len + args.max_gen_len,
- use_moe=args.use_moe)
- args.save_dir = os.path.join(args.out_dir)
- os.makedirs(args.save_dir, exist_ok=True)
- os.makedirs(args.out_dir, exist_ok=True)
-
- ctx = nullcontext() if "cuda" not in args.device else torch.amp.autocast('cuda')
- ddp = int(os.environ.get("RANK", -1)) != -1
- ddp_local_rank, DEVICE = 0, "cuda:0"
-
- base_seed = 1337
- torch.manual_seed(base_seed)
- torch.cuda.manual_seed(base_seed)
-
- if ddp:
- init_distributed_mode()
- args.device = torch.device(DEVICE)
- rank = dist.get_rank()
- torch.manual_seed(base_seed + rank)
- # 同时设置 CUDA 的随机种子
- torch.cuda.manual_seed(base_seed + rank)
-
- if args.use_wandb and (not ddp or ddp_local_rank == 0):
- import swanlab as wandb
-
- wandb.init(project=args.wandb_project)
- else:
- wandb = None
-
- model, ref_model, tokenizer, reward_model, reward_tokenizer = init_model(lm_config=lm_config)
+ reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
+ # Data and optimizer
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
- train_sampler = DistributedSampler(train_ds) if ddp else None
- train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
+ train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+ optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
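+ # A throwaway DataLoader is built only to count batches per epoch so the cosine schedule can
+ # span every optimizer step; the real per-epoch loaders are created in the training loop below.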
+ loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
+ iters = len(loader_for_count)
+ total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
+ scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
+
+ # ========== 6. Restore state from checkpoint ==========
+ start_epoch, start_step = 0, 0
+ if ckp_data:
+ model.load_state_dict(ckp_data['model'])
+ optimizer.load_state_dict(ckp_data['optimizer'])
+ scheduler.load_state_dict(ckp_data['scheduler'])
+ start_epoch = ckp_data['epoch']
+ start_step = ckp_data.get('step', 0)
+
+ # ========== 7. Wrap model with DDP ==========
+ if dist.is_initialized():
+ model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
+ model = DistributedDataParallel(model, device_ids=[local_rank])
+
+ # ========== 8. Start training ==========
+ for epoch in range(start_epoch, args.epochs):
+ train_sampler and train_sampler.set_epoch(epoch)
+ if epoch == start_epoch and start_step > 0: # first epoch when resuming from a checkpoint
+ batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+ loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+ Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+ grpo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
+ else: # default: start from scratch
+ loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
drop_last=False, shuffle=(train_sampler is None),
num_workers=args.num_workers, sampler=train_sampler)
-
- optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
- iter_per_epoch = len(train_loader)
- total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs
- scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
-
- if ddp:
- model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
- model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
- for epoch in range(args.epochs):
- train_sampler and train_sampler.set_epoch(epoch)
- grpo_train_epoch(epoch, wandb)
+ grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
diff --git a/trainer/train_lora.py b/trainer/train_lora.py
index df9f9ae..b6fc2b0 100644
--- a/trainer/train_lora.py
+++ b/trainer/train_lora.py
@@ -6,49 +6,39 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
-import math
import warnings
import torch
-from torch import optim, nn
import torch.distributed as dist
from contextlib import nullcontext
+from torch import optim, nn
+from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
-from model.model_lora import load_lora, save_lora, apply_lora
+from model.model_lora import save_lora, apply_lora
+from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
-# Logger function
-def Logger(content):
- if not ddp or dist.get_rank() == 0:
- print(content)
-
-
-def get_lr(current_step, total_steps, lr):
- return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-# 代码和full_sft「几乎」一致
-def train_epoch(epoch, wandb):
+def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
- for step, (X, Y, loss_mask) in enumerate(train_loader):
+ for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
- lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+ lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
- with ctx:
+ with autocast_ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
+
loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
loss = loss / args.accumulation_steps
@@ -64,146 +54,122 @@ def train_epoch(epoch, wandb):
optimizer.zero_grad(set_to_none=True)
- if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+ if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
- Logger(
- 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
- epoch + 1,
- args.epochs,
- step,
- iter_per_epoch,
- loss.item() * args.accumulation_steps,
- optimizer.param_groups[-1]['lr'],
- spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+ current_loss = loss.item() * args.accumulation_steps
+ current_lr = optimizer.param_groups[-1]['lr']
+ eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+ Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min')
+
+ if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
- if (wandb is not None) and (not ddp or dist.get_rank() == 0):
- wandb.log({"loss": loss.item() * args.accumulation_steps,
- "lr": optimizer.param_groups[-1]['lr'],
- "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
-
- if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+ if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
- lora_save_path = f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth'
- os.makedirs(os.path.dirname(lora_save_path), exist_ok=True)
- # 【区别1】只保存lora权重即可
+ lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
+ # LoRA: save only the LoRA weights
save_lora(model, lora_save_path)
+ lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
-def init_model(lm_config):
- tokenizer = AutoTokenizer.from_pretrained('../model/')
- model = MiniMindForCausalLM(lm_config)
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
- state_dict = torch.load(ckp, map_location=args.device)
- model.load_state_dict(state_dict, strict=False)
- return model.to(args.device), tokenizer
-
-
-def init_distributed_mode():
- if not ddp: return
- global ddp_local_rank, DEVICE
-
- dist.init_process_group(backend="nccl")
- ddp_rank = int(os.environ["RANK"])
- ddp_local_rank = int(os.environ["LOCAL_RANK"])
- ddp_world_size = int(os.environ["WORLD_SIZE"])
- DEVICE = f"cuda:{ddp_local_rank}"
- torch.cuda.set_device(DEVICE)
-
-
if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
- parser.add_argument("--out_dir", type=str, default="../out")
- parser.add_argument("--epochs", type=int, default=50)
- parser.add_argument("--batch_size", type=int, default=32)
- parser.add_argument("--learning_rate", type=float, default=1e-4)
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
- parser.add_argument("--dtype", type=str, default="bfloat16")
- parser.add_argument("--use_wandb", action="store_true")
- parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA-SFT")
- parser.add_argument("--num_workers", type=int, default=1)
- parser.add_argument("--ddp", action="store_true")
- parser.add_argument("--accumulation_steps", type=int, default=1)
- parser.add_argument("--grad_clip", type=float, default=1.0)
- parser.add_argument("--warmup_iters", type=int, default=0)
- parser.add_argument("--log_interval", type=int, default=10)
- parser.add_argument("--save_interval", type=int, default=1)
- parser.add_argument('--local_rank', type=int, default=-1)
- parser.add_argument('--hidden_size', default=512, type=int)
- parser.add_argument('--num_hidden_layers', default=8, type=int)
- parser.add_argument('--max_seq_len', default=512, type=int)
- parser.add_argument('--use_moe', default=False, type=bool)
- parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl")
- parser.add_argument("--lora_name", type=str, default="lora_identity", help="根据任务保存成lora_(英文/医学/心理...)")
+ parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
+ parser.add_argument("--save_dir", type=str, default="../out/lora", help="模型保存目录")
+ parser.add_argument("--lora_name", type=str, default="lora_identity", help="LoRA权重名称(如lora_identity/lora_medical等)")
+ parser.add_argument("--epochs", type=int, default=50, help="训练轮数")
+ parser.add_argument("--batch_size", type=int, default=32, help="batch size")
+ parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率")
+ parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
+ parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
+ parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
+ parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔")
+ parser.add_argument("--save_interval", type=int, default=1, help="模型保存间隔")
+ parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
+ parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
+ parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
+ parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
+ parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl", help="LoRA训练数据路径")
+ parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练,默认full_sft")
+ parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
+ parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
+ parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名")
args = parser.parse_args()
- lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
- use_moe=args.use_moe)
- args.save_dir = os.path.join(args.out_dir)
+ # ========== 1. Initialize environment and random seeds ==========
+ local_rank = init_distributed_mode()
+ if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+ setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+ # ========== 2. Set up directories, model config, and check for checkpoints ==========
os.makedirs(args.save_dir, exist_ok=True)
- os.makedirs(args.out_dir, exist_ok=True)
- tokens_per_iter = args.batch_size * args.max_seq_len
+ lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+ ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
+
+ # ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
-
- ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
- ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
- ddp_local_rank, DEVICE = 0, "cuda:0"
- base_seed = 1337
- torch.manual_seed(base_seed)
- torch.cuda.manual_seed(base_seed)
-
- if ddp:
- init_distributed_mode()
- args.device = torch.device(DEVICE)
- rank = dist.get_rank()
- torch.manual_seed(base_seed + rank)
- # 同时设置 CUDA 的随机种子
- torch.cuda.manual_seed(base_seed + rank)
-
- args.wandb_run_name = f"MiniMind-Lora-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
- if args.use_wandb and (not ddp or ddp_local_rank == 0):
+ dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+ autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+ # ========== 4. Configure wandb ==========
+ wandb = None
+ if args.use_wandb and is_main_process():
import swanlab as wandb
-
- wandb.init(project=args.wandb_project, name=args.wandb_run_name)
- else:
- wandb = None
-
- model, tokenizer = init_model(lm_config)
+ wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+ resume = 'must' if wandb_id else None
+ wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
+ wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+ # ========== 5. Define model, apply LoRA, freeze non-LoRA parameters ==========
+ model, tokenizer = init_model(lm_config, args.from_weight)
apply_lora(model)
-
- total_params = sum(p.numel() for p in model.parameters()) # 总参数数量
- lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name) # LoRA 参数数量
- if not ddp or dist.get_rank() == 0:
- print(f"LLM 总参数量: {total_params}")
- print(f"LoRA 参数量: {lora_params_count}")
- print(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
-
- for name, param in model.named_parameters():
- if 'lora' not in name:
- param.requires_grad = False
+
+ # Parameter counts
+ total_params = sum(p.numel() for p in model.parameters())
+ lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
+ Logger(f"Total LLM parameters: {total_params / 1e6:.3f} M")
+ Logger(f"LoRA parameters: {lora_params_count / 1e6:.3f} M")
+ Logger(f"LoRA parameter share: {lora_params_count / total_params * 100:.2f}%")
+
+ # Freeze non-LoRA parameters and collect the LoRA ones
lora_params = []
for name, param in model.named_parameters():
if 'lora' in name:
+ param.requires_grad = True
lora_params.append(param)
-
- # 只对 LoRA 参数进行优化
- optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
+ else:
+ param.requires_grad = False
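+ # Only the injected LoRA matrices remain trainable; the frozen base weights are untouched,
+ # which is why the optimizer below receives lora_params alone.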
+
+ # ========== 6. Define data and optimizer ==========
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
- train_sampler = DistributedSampler(train_ds) if ddp else None
- train_loader = DataLoader(
- train_ds,
- batch_size=args.batch_size,
- pin_memory=True,
- drop_last=False,
- shuffle=(train_sampler is None),
- num_workers=args.num_workers,
- sampler=train_sampler
- )
-
- scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
- iter_per_epoch = len(train_loader)
-
- for epoch in range(args.epochs):
+ train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+ scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
+ optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
+
+ # ========== 7. Restore state from checkpoint ==========
+ start_epoch, start_step = 0, 0
+ if ckp_data:
+ model.load_state_dict(ckp_data['model'], strict=False)
+ optimizer.load_state_dict(ckp_data['optimizer'])
+ scaler.load_state_dict(ckp_data['scaler'])
+ start_epoch = ckp_data['epoch']
+ start_step = ckp_data.get('step', 0)
+
+ # ========== 8. Wrap model with DDP ==========
+ if dist.is_initialized():
+ model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
+ model = DistributedDataParallel(model, device_ids=[local_rank])
+
+ # ========== 9. Start training ==========
+ for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
- train_epoch(epoch, wandb)
+ if epoch == start_epoch and start_step > 0: # first epoch when resuming from a checkpoint
+ batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+ loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+ Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+ train_epoch(epoch, loader, len(loader) + start_step + 1, lora_params, start_step, wandb)
+ else: # default: start from scratch
+ loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+ train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py
index 7652c22..27a82b7 100644
--- a/trainer/train_ppo.py
+++ b/trainer/train_ppo.py
@@ -6,27 +6,43 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import re
+import warnings
import torch
import torch.distributed as dist
import torch.nn.functional as F
+from transformers import AutoTokenizer
+from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModel
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
-from dataset.lm_dataset import RLAIFDataset
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
+from transformers import AutoModel
+from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from dataset.lm_dataset import RLAIFDataset
+from trainer.trainer_utils import *
+
+warnings.filterwarnings('ignore')
-def Logger(content):
- if not ddp or dist.get_rank() == 0:
- print(content)
+# Custom critic model, subclassing MiniMindForCausalLM
+class CriticModel(MiniMindForCausalLM):
+ def __init__(self, params):
+ super().__init__(params)
+ # Value head: a linear layer that outputs a single scalar value (takes over the role of lm_head)
+ self.value_head = nn.Linear(params.hidden_size, 1)
+
+ def forward(self, input_ids=None, attention_mask=None, **kwargs):
+ # Get hidden states from the base model
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
+ hidden_states = self.model.norm(outputs[0])
+ # value_head yields one scalar value estimate per token -> [B, T]
+ values = self.value_head(hidden_states).squeeze(-1)
+ return values
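+ # The critic reuses the full MiniMind backbone and emits a value for every token position;
+ # the PPO loss presumably reads the values at the response positions when computing advantages.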
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
-
def reasoning_model_reward(rewards):
# 1. Format reward (only used when training a reasoning model)
pattern = r"^\n.*?\n\n\n.*?\n$"
@@ -66,7 +82,7 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
# Format reward
if args.reasoning == 1:
- rewards = reasoning_model_reward(rewards) # 训练推理模型时使用
+ rewards = reasoning_model_reward(rewards)
# Use the reward model to score the whole response
with torch.no_grad():
@@ -91,7 +107,6 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
answer_score = max(min(answer_score, scale), -scale)
-
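+ # Blend the whole-conversation score (40%) with the score of the extracted answer content (60%).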
score = score * 0.4 + answer_score * 0.6
reward_model_scores.append(score)
@@ -101,19 +116,20 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
-def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_scheduler, critic_scheduler):
+def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step=0, wandb=None):
actor_model.train()
critic_model.train()
- is_master = (not ddp) or dist.get_rank() == 0
- for step, batch in enumerate(train_loader):
+ for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch["prompt"] # list[str], length B
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
max_length=args.max_seq_len).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
prompt_lengths = enc.attention_mask.sum(dim=1) # [B]
with torch.no_grad():
- gen_out = actor_model.generate(
+ # A DDP-wrapped model exposes generate() via .module
+ model_for_gen = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
+ gen_out = model_for_gen.generate(
input_ids=enc.input_ids, attention_mask=enc.attention_mask,
max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id) # [B, P+R]
@@ -164,7 +180,7 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
- if is_master:
+ if is_main_process():
response_ids = gen_out[:, enc.input_ids.shape[1]:]
is_eos = (response_ids == tokenizer.eos_token_id)
eos_indices = torch.argmax(is_eos.int(), dim=1)
@@ -181,8 +197,8 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
actor_lr = actor_optimizer.param_groups[0]['lr']
critic_lr = critic_optimizer.param_groups[0]['lr']
- if wandb_run is not None:
- wandb_run.log({
+ if wandb is not None:
+ wandb.log({
"actor_loss": actor_loss_val,
"critic_loss": critic_loss_val,
"reward": reward_val,
@@ -192,183 +208,158 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
"actor_lr": actor_lr,
})
- Logger(f"Epoch: {epoch}, Step: {step + 1}/{len(train_loader)}, "
+ Logger(f"Epoch: {epoch+1}, Step: {step}/{iters}, "
f"Actor Loss: {actor_loss_val:.6f}, Critic Loss: {critic_loss_val:.6f}, "
f"Reward: {reward_val:.6f}, KL: {kl_val:.6f}, KL_ref: {kl_ref_val:.6f}, "
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.2e}, Critic LR: {critic_lr:.2e}")
if (step + 1) % args.update_old_actor_freq == 0:
- state_dict = actor_model.module.state_dict() if isinstance(actor_model, torch.nn.parallel.DistributedDataParallel) else actor_model.state_dict()
+ state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()})
old_actor_model.to(args.device)
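+ # The behaviour policy (old_actor_model) is refreshed from the current actor every
+ # update_old_actor_freq batches; PPO's clipped ratio is computed against its log-probs.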
- if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+ if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
actor_model.eval()
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/ppo_actor_{lm_config.hidden_size}{moe_path}.pth'
-
- if isinstance(actor_model, torch.nn.parallel.DistributedDataParallel):
- state_dict = actor_model.module.state_dict()
- else:
- state_dict = actor_model.state_dict()
-
- state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
- torch.save(state_dict, ckp)
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+ actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
+ torch.save({k: v.half() for k, v in actor_state.items()}, ckp)
+
+ # Use lm_checkpoint to save the full training state (including the critic)
+ lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
+ epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints',
+ scheduler=actor_scheduler, critic_model=critic_model,
+ critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
actor_model.train()
-# 自定义的Critic模型,继承自MiniMindLM
-class CriticModel(MiniMindForCausalLM):
- def __init__(self, params):
- super().__init__(params)
- # 替换lm_head为输出单一价值的线性层
- self.value_head = nn.Linear(params.hidden_size, 1)
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)")
+ parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
+ parser.add_argument('--save_weight', default='ppo_actor', type=str, help="保存权重的前缀名")
+ parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
+ parser.add_argument("--batch_size", type=int, default=2, help="batch size")
+ parser.add_argument("--learning_rate", type=float, default=8e-8, help="Actor学习率")
+ parser.add_argument("--critic_learning_rate", type=float, default=8e-8, help="Critic学习率")
+ parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
+ parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
+ parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
+ parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
+ parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
+ parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
+ parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
+ parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
+ parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
+ parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
+ parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
+ parser.add_argument("--clip_epsilon", type=float, default=0.1, help="PPO裁剪参数")
+ parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数")
+ parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
+ parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
+ parser.add_argument("--update_old_actor_freq", type=int, default=4, help="更新old_actor_model的频率")
+ parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
+ parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
+ parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
+ parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
+ args = parser.parse_args()
- def forward(self, input_ids=None, attention_mask=None, **kwargs):
- # 使用基础模型获取隐藏状态
- outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
- # self.model 返回的是一个元组,第一个元素是 last_hidden_state
- hidden_states = self.model.norm(outputs[0])
- # 使用value_head获取价值估计
- values = self.value_head(hidden_states).squeeze(-1)
- return values
-
-
-def init_model(lm_config):
+ # ========== 1. Initialize environment and random seeds ==========
+ local_rank = init_distributed_mode()
+ if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+ setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+ # ========== 2. Set up directories, model config, and check for checkpoints ==========
+ os.makedirs(args.save_dir, exist_ok=True)
+ lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+ ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
+
+ # ========== 3. Set up mixed precision ==========
+ device_type = "cuda" if "cuda" in args.device else "cpu"
+ dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+ autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+ # ========== 4. Configure wandb ==========
+ wandb = None
+ if args.use_wandb and is_main_process():
+ import swanlab as wandb
+ wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+ resume = 'must' if wandb_id else None
+ wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
+ wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+ # ========== 5. Initialize models and data ==========
tokenizer = AutoTokenizer.from_pretrained('../model/', padding_side='left')
- if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
-
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/{"reason" if args.reasoning == 1 else "full_sft"}_{lm_config.hidden_size}{moe_path}.pth'
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ base_weight = "reason" if args.reasoning == 1 else "full_sft"
+ ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = torch.load(ckp, map_location=args.device)
-
+ # Actor model
actor_model = MiniMindForCausalLM(lm_config)
actor_model.load_state_dict(state_dict, strict=False)
actor_model = actor_model.to(args.device)
-
+ Logger(f'Actor model total parameters: {sum(p.numel() for p in actor_model.parameters() if p.requires_grad) / 1e6:.3f} M')
+ # Old actor model
old_actor_model = MiniMindForCausalLM(lm_config)
old_actor_model.load_state_dict(state_dict, strict=False)
old_actor_model = old_actor_model.eval().requires_grad_(False).to(args.device)
-
+ # Reference model
ref_model = MiniMindForCausalLM(lm_config)
ref_model.load_state_dict(state_dict, strict=False)
ref_model = ref_model.eval().requires_grad_(False).to(args.device)
-
+ # Critic model
critic_model = CriticModel(lm_config)
critic_model.load_state_dict(state_dict, strict=False)
critic_model = critic_model.to(args.device)
-
- reward_name = "../../internlm2-1_8b-reward"
+ Logger(f'Critic model total parameters: {sum(p.numel() for p in critic_model.parameters() if p.requires_grad) / 1e6:.3f} M')
+ # Reward model
reward_model = AutoModel.from_pretrained(
- reward_name, device_map="cuda", torch_dtype=torch.float32, trust_remote_code=True
+ args.reward_model_path, device_map="cuda", torch_dtype=torch.float32, trust_remote_code=True
).to(args.device).eval().requires_grad_(False)
- reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True)
-
- Logger(f'Actor模型总参数量:{sum(p.numel() for p in actor_model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
- Logger(f'Critic模型总参数量:{sum(p.numel() for p in critic_model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
-
- return actor_model, old_actor_model, ref_model, critic_model, reward_model, tokenizer, reward_tokenizer
-
-
-def init_distributed_mode():
- if not ddp: return
- global ddp_local_rank, DEVICE
- dist.init_process_group(backend="nccl")
- ddp_local_rank = int(os.environ["LOCAL_RANK"])
- DEVICE = f"cuda:{ddp_local_rank}"
- torch.cuda.set_device(DEVICE)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--out_dir", type=str, default="../out")
- parser.add_argument("--epochs", type=int, default=1)
- parser.add_argument("--batch_size", type=int, default=2)
- parser.add_argument("--learning_rate", type=float, default=8e-8)
- parser.add_argument("--critic_learning_rate", type=float, default=8e-8)
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
- parser.add_argument("--dtype", type=str, default="bfloat16")
- parser.add_argument("--use_wandb", action="store_true")
- parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO")
- parser.add_argument("--num_workers", type=int, default=1)
- parser.add_argument("--ddp", action="store_true")
- parser.add_argument("--accumulation_steps", type=int, default=1)
- parser.add_argument("--grad_clip", type=float, default=1.0)
- parser.add_argument("--log_interval", type=int, default=1)
- parser.add_argument("--save_interval", type=int, default=10)
- parser.add_argument('--hidden_size', default=512, type=int)
- parser.add_argument('--num_hidden_layers', default=8, type=int)
- parser.add_argument('--use_moe', default=False, type=bool)
- parser.add_argument('--max_seq_len', default=66, type=int)
- parser.add_argument("--max_gen_len", type=int, default=1536)
- parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
- parser.add_argument("--clip_epsilon", type=float, default=0.1)
- parser.add_argument("--vf_coef", type=float, default=0.5)
- parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
- parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
- parser.add_argument("--update_old_actor_freq", type=int, default=4, help="频率:每处理n个batch后更新old_actor_model")
- args = parser.parse_args()
-
- lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
- use_moe=args.use_moe)
- args.save_dir = os.path.join(args.out_dir)
- os.makedirs(args.save_dir, exist_ok=True)
- os.makedirs(args.out_dir, exist_ok=True)
-
- ddp = int(os.environ.get("RANK", -1)) != -1
- ddp_local_rank, DEVICE = 0, "cuda:0"
- base_seed = 1337
- torch.manual_seed(base_seed)
- torch.cuda.manual_seed(base_seed)
-
- if ddp:
- init_distributed_mode()
- args.device = torch.device(DEVICE)
- rank = dist.get_rank()
- torch.manual_seed(base_seed + rank)
- # 同时设置 CUDA 的随机种子
- torch.cuda.manual_seed(base_seed + rank)
-
- if args.use_wandb and (not ddp or ddp_local_rank == 0):
- import swanlab as wandb
-
- wandb.init(project=args.wandb_project)
- else:
- wandb = None
-
- # 初始化所有模型
- actor_model, old_actor_model, ref_model, critic_model, reward_model, tokenizer, reward_tokenizer = init_model(lm_config=lm_config)
-
- # 准备数据集和数据加载器
+ reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
+ # Data and optimizer
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len))
- train_sampler = DistributedSampler(train_ds) if ddp else None
- train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
- drop_last=False, shuffle=(train_sampler is None),
- num_workers=args.num_workers, sampler=train_sampler)
-
- # 初始化优化器
+ train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate)
-
- iter_per_epoch = len(train_loader)
- total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs
+ loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
+ iters = len(loader_for_count)
+ total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
- critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps,
- eta_min=args.critic_learning_rate / 10)
-
- # 如果使用分布式训练,包装模型
- if ddp:
+ critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
+
+ # ========== 6. Restore state from checkpoint ==========
+ start_epoch, start_step = 0, 0
+ if ckp_data:
+ actor_model.load_state_dict(ckp_data['model'])
+ critic_model.load_state_dict(ckp_data['critic_model'])
+ actor_optimizer.load_state_dict(ckp_data['optimizer'])
+ critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
+ actor_scheduler.load_state_dict(ckp_data['scheduler'])
+ critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
+ start_epoch = ckp_data['epoch']
+ start_step = ckp_data.get('step', 0)
+
+ # ========== 7. Wrap models with DDP ==========
+ if dist.is_initialized():
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
- actor_model = DistributedDataParallel(actor_model, device_ids=[ddp_local_rank])
- critic_model = DistributedDataParallel(critic_model, device_ids=[ddp_local_rank])
- # old_actor_model 不需要DDP包装,因为它只在主进程上用于计算,并且不进行梯度更新
+ actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
+ critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
old_actor_model.to(args.device)
-
- for epoch in range(args.epochs):
+
+ # ========== 8. Start training ==========
+ for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
- ppo_train_epoch(epoch, wandb, old_actor_model, ref_model, actor_scheduler, critic_scheduler)
-
- if ddp:
- dist.destroy_process_group()
+ if epoch == start_epoch and start_step > 0: # first epoch when resuming from a checkpoint
+ batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
+ loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+ Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
+ ppo_train_epoch(epoch, loader, len(loader) + start_step + 1, old_actor_model, ref_model,
+ actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step, wandb)
+ else: # default: start from scratch
+ loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None),
+ sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+ ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model,
+ actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb)
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index 36a3cd8..18cc445 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -6,48 +6,38 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
-import math
import warnings
import torch
import torch.distributed as dist
+from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
-from contextlib import nullcontext
-from transformers import AutoTokenizer
-from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
+from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import PretrainDataset
+from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
-def Logger(content):
- if not ddp or dist.get_rank() == 0:
- print(content)
-
-
-def get_lr(current_step, total_steps, lr):
- return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
-
-
-def train_epoch(epoch, wandb):
+def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
- for step, (X, Y, loss_mask) in enumerate(train_loader):
+ for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
-
- lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+ lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
- with ctx:
+ with autocast_ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
+
loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
loss = loss / args.accumulation_steps
@@ -63,139 +53,108 @@ def train_epoch(epoch, wandb):
optimizer.zero_grad(set_to_none=True)
- if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+ if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
- Logger(
- 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
- epoch + 1,
- args.epochs,
- step,
- iter_per_epoch,
- loss.item() * args.accumulation_steps,
- optimizer.param_groups[-1]['lr'],
- spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+ current_loss = loss.item() * args.accumulation_steps
+ current_lr = optimizer.param_groups[-1]['lr']
+ eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
+
+ Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min')
+
+ if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
- if (wandb is not None) and (not ddp or dist.get_rank() == 0):
- wandb.log({"loss": loss.item() * args.accumulation_steps,
- "lr": optimizer.param_groups[-1]['lr'],
- "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
-
- if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+ if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/pretrain_{lm_config.hidden_size}{moe_path}.pth'
-
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
-
state_dict = {k: v.half() for k, v in state_dict.items()} # save in half precision
torch.save(state_dict, ckp)
+ lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
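+ # Each save produces two artifacts: fp16 inference weights in args.save_dir, and a full
+ # training state (model/optimizer/scaler/epoch/step/wandb id) under ../checkpoints for --from_resume.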
model.train()
-def init_model(lm_config):
- tokenizer = AutoTokenizer.from_pretrained('../model/')
- model = MiniMindForCausalLM(lm_config).to(args.device)
- Logger(f'LLM可训练总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
- return model, tokenizer
-
-
-def init_distributed_mode():
- if not ddp: return
- global ddp_local_rank, DEVICE
-
- dist.init_process_group(backend="nccl")
- ddp_rank = int(os.environ["RANK"])
- ddp_local_rank = int(os.environ["LOCAL_RANK"])
- ddp_world_size = int(os.environ["WORLD_SIZE"])
- DEVICE = f"cuda:{ddp_local_rank}"
- torch.cuda.set_device(DEVICE)
-
-
-# torchrun --nproc_per_node 2 1-pretrain.py
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
- parser.add_argument("--out_dir", type=str, default="../out")
- # 若要以最快速度实现zero则epochs设置为1轮;否则应当利用有限的数据训练2~6个epochs。
- parser.add_argument("--epochs", type=int, default=1)
- parser.add_argument("--batch_size", type=int, default=32)
- parser.add_argument("--learning_rate", type=float, default=5e-4)
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
- parser.add_argument("--dtype", type=str, default="bfloat16")
- parser.add_argument("--use_wandb", action="store_true")
- parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
- parser.add_argument("--num_workers", type=int, default=1)
- parser.add_argument("--ddp", action="store_true")
- parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--grad_clip", type=float, default=1.0)
- parser.add_argument("--warmup_iters", type=int, default=0)
- parser.add_argument("--log_interval", type=int, default=100)
- parser.add_argument("--save_interval", type=int, default=100)
- parser.add_argument('--local_rank', type=int, default=-1)
- parser.add_argument('--hidden_size', default=512, type=int)
- parser.add_argument('--num_hidden_layers', default=8, type=int)
- parser.add_argument('--max_seq_len', default=512, type=int)
- parser.add_argument('--use_moe', default=False, type=bool)
- parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl")
+ parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
+ parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
+ parser.add_argument("--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)")
+ parser.add_argument("--batch_size", type=int, default=32, help="batch size")
+ parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
+ parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
+ parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
+ parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数")
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
+ parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
+ parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
+ parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
+ parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
+ parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
+ parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
+ parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
+ parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始")
+ parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
+ parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
+ parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
args = parser.parse_args()
- lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
- use_moe=args.use_moe)
- args.save_dir = os.path.join(args.out_dir)
+ # ========== 1. Initialize environment and random seed ==========
+ local_rank = init_distributed_mode()
+ if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+ setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+
+ # ========== 2. Set up directories, model config, and check for a checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
- os.makedirs(args.out_dir, exist_ok=True)
- tokens_per_iter = args.batch_size * args.max_seq_len
+ lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
+ ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
+
+ # ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
-
- args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-
- ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
-
- ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
- ddp_local_rank, DEVICE = 0, "cuda:0"
-
- base_seed = 1337
- torch.manual_seed(base_seed)
- torch.cuda.manual_seed(base_seed)
-
- if ddp:
- init_distributed_mode()
- args.device = torch.device(DEVICE)
- rank = dist.get_rank()
- torch.manual_seed(base_seed + rank)
- # 同时设置 CUDA 的随机种子
- torch.cuda.manual_seed(base_seed + rank)
-
- if args.use_wandb and (not ddp or ddp_local_rank == 0):
+ dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+ autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+ # ========== 4. Set up wandb ==========
+ wandb = None
+ if args.use_wandb and is_main_process():
import swanlab as wandb
-
- wandb.init(project=args.wandb_project, name=args.wandb_run_name)
- else:
- wandb = None
-
- model, tokenizer = init_model(lm_config)
+ wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+ resume = 'must' if wandb_id else None
+ wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+ wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+ # ========== 5. Define model, data, and optimizer ==========
+ model, tokenizer = init_model(lm_config, args.from_weight)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
- train_sampler = DistributedSampler(train_ds) if ddp else None
- train_loader = DataLoader(
- train_ds,
- batch_size=args.batch_size,
- pin_memory=True,
- drop_last=False,
- shuffle=(train_sampler is None),
- num_workers=args.num_workers,
- sampler=train_sampler
- )
-
- scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+ train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+ scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
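+ # GradScaler is only enabled for float16: bfloat16 keeps fp32's exponent
+ # range, so loss scaling is unnecessary there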
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
- if ddp:
+
+ # ========== 6. Restore state from checkpoint ==========
+ start_epoch, start_step = 0, 0
+ if ckp_data:
+ model.load_state_dict(ckp_data['model'])
+ optimizer.load_state_dict(ckp_data['optimizer'])
+ scaler.load_state_dict(ckp_data['scaler'])
+ start_epoch = ckp_data['epoch']
+ start_step = ckp_data.get('step', 0)
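+ # 'step' counts completed steps, so training resumes at start_step + 1
+ # via SkipBatchSampler below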
+
+ # ========== 7. Wrap model with DDP ==========
+ if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
- model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
- iter_per_epoch = len(train_loader)
- for epoch in range(args.epochs):
+ model = DistributedDataParallel(model, device_ids=[local_rank])
+
+ # ========== 8. Start training ==========
+ for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
- train_epoch(epoch, wandb)
+ if epoch == start_epoch and start_step > 0: # first epoch after a resume: skip the completed steps
+ batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step)
+ loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+ Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping first {start_step} steps, resuming from step {start_step + 1}')
+ train_epoch(epoch, loader, len(loader) + start_step, start_step, wandb)
+ else: # default: train from the start of the epoch
+ loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
+ train_epoch(epoch, loader, len(loader), 0, wandb)
diff --git a/trainer/train_spo.py b/trainer/train_spo.py
index e13e741..74dc72c 100755
--- a/trainer/train_spo.py
+++ b/trainer/train_spo.py
@@ -3,31 +3,35 @@ import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
import argparse
-import time
import re
import gc
+import warnings
import torch
-from contextlib import nullcontext
import torch.distributed as dist
+from transformers import AutoTokenizer
+from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from collections import defaultdict
+from trainer.trainer_utils import *
+
+warnings.filterwarnings('ignore')
class AutoAdaptiveValueTracker:
+ """SPO自适应价值追踪器"""
def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
self.rho_mode = rho_mode
self.rho_const = rho_const
self.D_half = D_half
self.clip_lower = clip_lower
self.clip_upper = clip_upper
- # Stable initialization following N_init = 1/(1-clip_lower)
N_init = 1.0 / (1.0 - self.clip_lower)
self.alpha = 0.5 * N_init
self.beta = 0.5 * N_init
@@ -62,43 +66,28 @@ class AutoAdaptiveValueTracker:
return rho
-def Logger(content):
- if not ddp or dist.get_rank() == 0:
- print(content)
-
-
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
-
def reasoning_model_reward(rewards):
- # 1. 格式奖励(仅针对训练推理模型时使用)
pattern = r"^\n.*?\n\n\n.*?\n$"
pattern2 = r"^\n.*?\n\n\n\n.*?\n$"
-
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
- if match_pattern:
- format_rewards.append(0.5)
- elif match_pattern2:
+ if match_pattern or match_pattern2:
format_rewards.append(0.5)
else:
format_rewards.append(0.0)
rewards += torch.tensor(format_rewards, device=args.device)
- # 2. 标记奖励(防止严格奖励稀疏,仅针对训练推理模型时使用)
def mark_num(text):
reward = 0
- if text.count("") == 1:
- reward += 0.25
- if text.count("") == 1:
- reward += 0.25
- if text.count("") == 1:
- reward += 0.25
- if text.count("") == 1:
- reward += 0.25
+ if text.count("") == 1: reward += 0.25
+ if text.count("") == 1: reward += 0.25
+ if text.count("") == 1: reward += 0.25
+ if text.count("") == 1: reward += 0.25
return reward
mark_rewards = [mark_num(response) for response in responses]
@@ -106,12 +95,9 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
rewards = torch.zeros(len(responses), device=args.device)
-
- # 3. 格式奖励
if args.reasoning == 1:
- rewards = reasoning_model_reward(rewards) # 训练推理模型时使用
+ rewards = reasoning_model_reward(rewards)
- # 4. 使用reward model计算奖励
with torch.no_grad():
reward_model_scores = []
scale = 3.0
@@ -142,8 +128,8 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
-def spo_train_epoch(epoch, wandb, value_tracker):
- for step, batch in enumerate(train_loader):
+def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, value_tracker, start_step=0, wandb=None):
+ for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch['prompt'] # list[str], length B
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
@@ -152,7 +138,9 @@ def spo_train_epoch(epoch, wandb, value_tracker):
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
with torch.no_grad():
- outputs = (model.module if ddp else model).generate(
+ # a DDP-wrapped model exposes generate() via .module
+ model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
+ outputs = model_for_gen.generate(
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=1, pad_token_id=tokenizer.pad_token_id) # [B, P+R]
@@ -205,42 +193,38 @@ def spo_train_epoch(epoch, wandb, value_tracker):
scheduler.step()
optimizer.zero_grad()
- if step % args.log_interval == 0 or step == iter_per_epoch - 1:
+ if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item()
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
- # average kl over valid tokens for logging
kl_val = ((per_token_kl * completion_mask).sum() / (completion_mask.sum() + 1e-8)).item()
avg_baseline_val = baselines.mean().item()
current_lr = optimizer.param_groups[0]['lr']
- Logger(
- f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
- f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
- f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
+ Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
+ f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
+ f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, '
+ f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
- if wandb and (not ddp or dist.get_rank() == 0):
- log_dict = {
+ if wandb and is_main_process():
+ wandb.log({
"policy_loss": policy_loss_val,
"reward": avg_reward_val,
"kl": kl_val,
"rho": float(rho),
"baseline": avg_baseline_val,
- # "avg_response_len": avg_len_val,
"advantages_mean": advantages.mean().item(),
"learning_rate": current_lr
- }
- wandb.log(log_dict)
+ })
- if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
+ if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
- moe_path = '_moe' if lm_config.use_moe else ''
- suffix = 'spo'
- ckp = f'{args.save_dir}/{suffix}_{lm_config.hidden_size}{moe_path}.pth'
-
- state_dict = model.module.state_dict() if isinstance(model,
- torch.nn.parallel.DistributedDataParallel) else model.state_dict()
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+ state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
+ lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
+ epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
@@ -249,120 +233,116 @@ def spo_train_epoch(epoch, wandb, value_tracker):
gc.collect()
-def init_model(lm_config):
- tokenizer = AutoTokenizer.from_pretrained('../model/')
- model = MiniMindForCausalLM(lm_config)
- moe_path = '_moe' if lm_config.use_moe else ''
- ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
- if args.reasoning == 1:
- ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
- state_dict = torch.load(ckp, map_location=args.device)
- model.load_state_dict(state_dict, strict=False)
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="MiniMind SPO (Self-Play Optimization)")
+ parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
+ parser.add_argument('--save_weight', default='spo', type=str, help="保存权重的前缀名")
+ parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
+ parser.add_argument("--batch_size", type=int, default=2, help="batch size")
+ parser.add_argument("--learning_rate", type=float, default=1e-7, help="初始学习率")
+ parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
+ parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
+ parser.add_argument("--accumulation_steps", type=int, default=4, help="梯度累积步数")
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
+ parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
+ parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
+ parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
+ parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
+ parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
+ parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
+ parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
+ parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
+ parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
+ parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
+ parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
+ parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
+ parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
+ parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO", help="wandb项目名")
+ args = parser.parse_args()
+
+ # ========== 1. Initialize environment and random seed ==========
+ local_rank = init_distributed_mode()
+ if dist.is_initialized(): args.device = f"cuda:{local_rank}"
+ setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
+ # ========== 2. Set up directories, model config, and check for a checkpoint ==========
+ os.makedirs(args.save_dir, exist_ok=True)
+ lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
+ max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe)
+ ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
+
+ # ========== 3. Set up mixed precision ==========
+ device_type = "cuda" if "cuda" in args.device else "cpu"
+ dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+ autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
+
+ # ========== 4. Set up wandb ==========
+ wandb = None
+ if args.use_wandb and is_main_process():
+ import swanlab as wandb
+ wandb_id = ckp_data.get('wandb_id') if ckp_data else None
+ resume = 'must' if wandb_id else None
+ wandb_run_name = f"MiniMind-SPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
+ wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
+
+ # ========== 5. Initialize models (policy, reference, reward), value tracker, and data ==========
+ tokenizer = AutoTokenizer.from_pretrained('../model/')
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ base_weight = "reason" if args.reasoning == 1 else "full_sft"
+ ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+ state_dict = torch.load(ckp, map_location=args.device)
+ # Policy model
+ model = MiniMindForCausalLM(lm_config)
+ model.load_state_dict(state_dict, strict=False)
+ model = model.to(args.device)
+ Logger(f'Policy model trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M')
+ # Reference model (frozen, for the KL penalty)
ref_model = MiniMindForCausalLM(lm_config)
ref_model.load_state_dict(state_dict, strict=False)
ref_model.eval().requires_grad_(False)
-
- Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
- model = model.to(args.device)
ref_model = ref_model.to(args.device)
-
- reward_name = "../../internlm2-1_8b-reward"
+ # Reward model
reward_model = AutoModel.from_pretrained(
- reward_name,
- device_map="cuda",
- torch_dtype=torch.float16,
- trust_remote_code=True,
+ args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True
).to(args.device).eval().requires_grad_(False)
- reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True)
-
- return model, ref_model, tokenizer, reward_model, reward_tokenizer
-
-
-def init_distributed_mode():
- if not ddp: return
- global ddp_local_rank, DEVICE
- dist.init_process_group(backend="nccl")
- ddp_local_rank = int(os.environ["LOCAL_RANK"])
- DEVICE = f"cuda:{ddp_local_rank}"
- torch.cuda.set_device(DEVICE)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--out_dir", type=str, default="../out")
- parser.add_argument("--epochs", type=int, default=1)
- parser.add_argument("--batch_size", type=int, default=2)
- parser.add_argument("--learning_rate", type=float, default=1e-7)
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
- parser.add_argument("--dtype", type=str, default="bfloat16")
- parser.add_argument("--use_wandb", action="store_true")
- parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO")
- parser.add_argument("--num_workers", type=int, default=1)
- parser.add_argument("--ddp", action="store_true")
- parser.add_argument("--accumulation_steps", type=int, default=4)
- parser.add_argument("--grad_clip", type=float, default=1.0)
- parser.add_argument("--log_interval", type=int, default=1)
- parser.add_argument("--save_interval", type=int, default=10)
- parser.add_argument('--hidden_size', default=512, type=int)
- parser.add_argument('--num_hidden_layers', default=8, type=int)
- parser.add_argument('--use_moe', default=False, type=bool)
- parser.add_argument('--max_seq_len', default=66, type=int)
- parser.add_argument("--max_gen_len", type=int, default=1536)
- parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
- parser.add_argument("--beta", type=float, default=0.02)
- parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
- args = parser.parse_args()
-
- lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
- max_seq_len=args.max_seq_len + args.max_gen_len,
- use_moe=args.use_moe)
- args.save_dir = os.path.join(args.out_dir)
- os.makedirs(args.save_dir, exist_ok=True)
- os.makedirs(args.out_dir, exist_ok=True)
-
- ctx = nullcontext() if "cuda" not in args.device else torch.amp.autocast('cuda')
- ddp = int(os.environ.get("RANK", -1)) != -1
- ddp_local_rank, DEVICE = 0, "cuda:0"
-
- base_seed = 1337
- torch.manual_seed(base_seed)
- torch.cuda.manual_seed(base_seed)
-
- if ddp:
- init_distributed_mode()
- args.device = torch.device(DEVICE)
- rank = dist.get_rank()
- torch.manual_seed(base_seed + rank)
- # 同时设置 CUDA 的随机种子
- torch.cuda.manual_seed(base_seed + rank)
-
- if args.use_wandb and (not ddp or ddp_local_rank == 0):
- import swanlab as wandb
-
- wandb.init(project=args.wandb_project)
- else:
- wandb = None
-
- model, ref_model, tokenizer, reward_model, reward_tokenizer = init_model(lm_config=lm_config)
+ reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
+ # Value Tracker
+ value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
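+ # rho_mode='kl' adapts the tracker's update rate to the measured policy KL
+ # (D_half presumably the KL scale of that mapping), with rho clipped to [0.5, 0.96]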
+
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
- train_sampler = DistributedSampler(train_ds) if ddp else None
- train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
+ train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
+ optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
+
+ loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
+ iters = len(loader_for_count)
+ total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
+ scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
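+ # the scheduler is stepped once per optimizer update, not per batch, hence
+ # the division by accumulation_steps above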
+
+ # ========== 6. Restore state from checkpoint ==========
+ start_epoch, start_step = 0, 0
+ if ckp_data:
+ model.load_state_dict(ckp_data['model'])
+ optimizer.load_state_dict(ckp_data['optimizer'])
+ scheduler.load_state_dict(ckp_data['scheduler'])
+ start_epoch = ckp_data['epoch']
+ start_step = ckp_data.get('step', 0)
+
+ # ========== 7. Wrap model with DDP ==========
+ if dist.is_initialized():
+ model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
+ model = DistributedDataParallel(model, device_ids=[local_rank])
+
+ # ========== 8. Start training ==========
+ for epoch in range(start_epoch, args.epochs):
+ train_sampler and train_sampler.set_epoch(epoch)
+ if epoch == start_epoch and start_step > 0: # first epoch after a resume: skip the completed steps
+ batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step)
+ loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
+ Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping first {start_step} steps, resuming from step {start_step + 1}')
+ spo_train_epoch(epoch, loader, len(loader) + start_step, ref_model, reward_model, reward_tokenizer, value_tracker, start_step, wandb)
+ else: # default: train from the start of the epoch
+ loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
drop_last=False, shuffle=(train_sampler is None),
num_workers=args.num_workers, sampler=train_sampler)
-
- optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
-
- iter_per_epoch = len(train_loader)
- total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs
- scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
-
- if ddp:
- model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
- model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
- value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
-
- for epoch in range(args.epochs):
- train_sampler and train_sampler.set_epoch(epoch)
- spo_train_epoch(epoch, wandb, value_tracker)
+ spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
diff --git a/trainer/trainer_utils.py b/trainer/trainer_utils.py
new file mode 100644
index 0000000..9675c14
--- /dev/null
+++ b/trainer/trainer_utils.py
@@ -0,0 +1,139 @@
+"""
+训练工具函数集合
+"""
+import os
+import random
+import math
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+
+
+def is_main_process():
+ return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def Logger(content):
+ if is_main_process():
+ print(content)
+
+
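+# Warm cosine with a floor: starts at 1.1*lr and decays to lr/10,
+# e.g. get_lr(0, 100, 5e-4) == 5.5e-4 and get_lr(100, 100, 5e-4) == 5e-5.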
+def get_lr(current_step, total_steps, lr):
+ return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
+
+
+def init_distributed_mode():
+ if int(os.environ.get("RANK", -1)) == -1:
+ return 0 # not a DDP run; act as single-process local rank 0
+
+ dist.init_process_group(backend="nccl")
+ local_rank = int(os.environ["LOCAL_RANK"])
+ torch.cuda.set_device(local_rank)
+ return local_rank
+
+
+def setup_seed(seed: int):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
+ """Save a checkpoint when `model` is given; otherwise try to load one and return the resume dict (or None)."""
+ os.makedirs(save_dir, exist_ok=True)
+ moe_path = '_moe' if lm_config.use_moe else ''
+ ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
+ resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
+
+ if model is not None:
+ from torch.nn.parallel import DistributedDataParallel
+ state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
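+ # write to a temp file, then atomically replace, so a crash mid-save never
+ # leaves a truncated weights file behind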
+ ckp_tmp = ckp_path + '.tmp'
+ torch.save({k: v.half() for k, v in state_dict.items()}, ckp_tmp)
+ os.replace(ckp_tmp, ckp_path)
+ wandb_id = None
+ if wandb:
+ if hasattr(wandb, 'get_run'):
+ run = wandb.get_run()
+ wandb_id = getattr(run, 'id', None) if run else None
+ else:
+ wandb_id = getattr(wandb, 'id', None)
+
+ resume_data = {
+ 'model': state_dict,
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch,
+ 'step': step,
+ 'world_size': dist.get_world_size() if dist.is_initialized() else 1,
+ 'wandb_id': wandb_id
+ }
+ for key, value in kwargs.items():
+ if value is not None:
+ if hasattr(value, 'state_dict'):
+ if isinstance(value, DistributedDataParallel):
+ resume_data[key] = value.module.state_dict()
+ else:
+ resume_data[key] = value.state_dict()
+ else:
+ resume_data[key] = value
+
+ resume_tmp = resume_path + '.tmp'
+ torch.save(resume_data, resume_tmp)
+ os.replace(resume_tmp, resume_path)
+ else: # load mode: return resume data if a resume file exists
+ if os.path.exists(resume_path):
+ ckp_data = torch.load(resume_path, map_location='cpu')
+ saved_ws = ckp_data.get('world_size', 1)
+ current_ws = dist.get_world_size() if dist.is_initialized() else 1
+ if saved_ws != current_ws:
+ ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
+ Logger(f'World size changed ({saved_ws}→{current_ws}); step automatically rescaled to {ckp_data["step"]}')
+ return ckp_data
+ return None
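+# Typical usage: lm_checkpoint(cfg, weight='pretrain', model=m, optimizer=opt, ...)
+# saves; lm_checkpoint(cfg, weight='pretrain') probes for a resume file and
+# returns the loaded dict, or None when absent.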
+
+
+def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
+ from transformers import AutoTokenizer
+ from model.model_minimind import MiniMindForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ model = MiniMindForCausalLM(lm_config)
+
+ if from_weight != 'none':
+ moe_suffix = '_moe' if lm_config.use_moe else ''
+ weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+ weights = torch.load(weight_path, map_location=device)
+ model.load_state_dict(weights, strict=False)
+
+ Logger(f'Loaded model trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M')
+ return model.to(device), tokenizer
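+# e.g. init_model(lm_config, from_weight='pretrain') loads ../out/pretrain_512.pth
+# (for hidden_size=512, no MoE) and returns (model, tokenizer).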
+
+
+class SkipBatchSampler(Sampler):
+ """Batch sampler that skips the first `skip_batches` batches of `sampler` (used for mid-epoch resume)."""
+ def __init__(self, sampler, batch_size, skip_batches=0):
+ self.sampler = sampler
+ self.batch_size = batch_size
+ self.skip_batches = skip_batches
+
+ def __iter__(self):
+ batch = []
+ skipped = 0
+ for idx in self.sampler:
+ batch.append(idx)
+ if len(batch) == self.batch_size:
+ if skipped < self.skip_batches:
+ skipped += 1
+ batch = []
+ continue
+ yield batch
+ batch = []
+ if len(batch) > 0 and skipped >= self.skip_batches:
+ yield batch
+
+ def __len__(self):
+ total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
+ return max(0, total_batches - self.skip_batches)
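+# Example: len(sampler)=10, batch_size=4, skip_batches=1 -> skips [0..3],
+# yields [4,5,6,7] then the tail [8,9]; __len__() == ceil(10/4) - 1 == 2.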
+