[feat] pause-training

jingyaogong 2025-10-26 18:49:52 +08:00
parent 6efba3249a
commit e8484874f5
10 changed files with 1171 additions and 1328 deletions

View File

@@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import math
import warnings
import torch
import torch.distributed as dist
@@ -14,23 +13,14 @@ from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
def Logger(content):
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
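The cosine schedule removed here still drives the learning rate (the new code calls `get_lr`, presumably re-exported through `trainer_utils`): it decays from 1.1·lr at step 0 to a floor of 0.1·lr at the final step. A quick check of those endpoints:

```python
import math

def get_lr(current_step, total_steps, lr):
    # floor of lr/10 plus a cosine half-wave that starts at lr and decays to 0
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))

assert abs(get_lr(0, 1000, 1.0) - 1.1) < 1e-9     # first step: 1.1x the base lr
assert abs(get_lr(1000, 1000, 1.0) - 0.1) < 1e-9  # last step: 0.1x the base lr
```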
def train_epoch(epoch, wandb):
def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
# token ids of the think/answer tags
start_of_think_ids = tokenizer('<think>').input_ids
end_of_think_ids = tokenizer('</think>').input_ids
@@ -38,28 +28,30 @@ def train_epoch(epoch, wandb):
end_of_answer_ids = tokenizer('</answer>').input_ids
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(train_loader):
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with ctx:
with autocast_ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
# up-weight the special tag positions (specific to reasoning distillation)
sp_ids = torch.isin(Y.view(-1),
torch.tensor(start_of_think_ids + end_of_think_ids
+ start_of_answer_ids + end_of_answer_ids
).to(args.device))
# apply extra weight at the positions flagged by sp_ids
loss_mask = loss_mask.view(-1)
loss_mask_sum = loss_mask.sum()
loss_mask[sp_ids] = 10
loss_mask[sp_ids] = 10 # 10x weight on the think/answer tags
loss_mask = loss_mask.view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask_sum
loss += res.aux_loss
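A subtlety in the weighting above: `loss_mask_sum` is computed before the tag positions are set to 10, so the special tokens boost the numerator without inflating the denominator. A toy illustration with made-up ids and losses:

```python
import torch

Y = torch.tensor([[5, 42, 7]])            # pretend id 42 is '<think>'
loss_mask = torch.tensor([[1., 1., 0.]])  # last position is padding
sp_ids = torch.isin(Y.view(-1), torch.tensor([42]))
loss_mask = loss_mask.view(-1)
loss_mask_sum = loss_mask.sum()           # 2.0, taken before up-weighting
loss_mask[sp_ids] = 10
per_token_ce = torch.full((1, 3), 0.5)    # stand-in per-token CE values
loss = (per_token_ce * loss_mask.view(Y.size())).sum() / loss_mask_sum
print(loss)                               # (0.5 + 5.0 + 0.0) / 2 = 2.75
```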
@@ -70,148 +62,112 @@ def train_epoch(epoch, wandb):
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item() * args.accumulation_steps,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
current_loss = loss.item() * args.accumulation_steps
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss.item() * args.accumulation_steps,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
state_dict = {k: v.half() for k, v in state_dict.items()} # save in half precision
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('../model')
model = MiniMindForCausalLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False)
Logger(f'Total LLM params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
model = model.to(args.device)
return model, tokenizer
def init_distributed_mode():
if not ddp: return
global ddp_local_rank, DEVICE
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
DEVICE = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(DEVICE)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Distill Reasoning")
parser.add_argument("--out_dir", type=str, default="../out")
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=1e-6)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=1)
parser.add_argument("--save_interval", type=int, default=50)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=1024, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl")
parser = argparse.ArgumentParser(description="MiniMind Reasoning Distillation")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='reason', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=8, help="batch size")
parser.add_argument("--learning_rate", type=float, default=1e-6, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl", help="推理蒸馏数据路径")
parser.add_argument('--from_weight', default='dpo', type=str, help="基于哪个权重训练默认dpo")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Reasoning", help="wandb项目名")
args = parser.parse_args()
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
use_moe=args.use_moe)
args.save_dir = os.path.join(args.out_dir)
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Set up directories, model config, check for checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * args.max_seq_len
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
args.wandb_run_name = f"MiniMind-Distill-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# also seed CUDA
torch.cuda.manual_seed(base_seed + rank)
if args.use_wandb and (not ddp or ddp_local_rank == 0):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Set up wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
model, tokenizer = init_model(lm_config)
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define model, data, optimizer ==========
model, tokenizer = init_model(lm_config, args.from_weight)
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=(train_sampler is None),
num_workers=args.num_workers,
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Wrap model with DDP ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
train_epoch(epoch, wandb)
if epoch == start_epoch and start_step > 0: # first epoch and a checkpoint exists
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + start_step + 1, tokenizer, lm_config, start_step, wandb)
else: # default: start from the beginning
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), tokenizer, lm_config, 0, wandb)
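`SkipBatchSampler` is imported from `trainer_utils` and not shown in this commit. Below is a minimal sketch consistent with the call sites above; only the name and argument order come from the diff, the body is an assumption:

```python
from torch.utils.data import BatchSampler

class SkipBatchSampler(BatchSampler):
    # Hypothetical reconstruction: wrap a sampler (or a plain range of indices)
    # and skip the first `skip_batches` batches so a resumed epoch continues
    # mid-stream instead of replaying data it already trained on.
    def __init__(self, sampler, batch_size, skip_batches=0):
        super().__init__(sampler, batch_size, drop_last=False)
        self.skip_batches = skip_batches

    def __iter__(self):
        for i, batch in enumerate(super().__iter__()):
            if i >= self.skip_batches:
                yield batch

    def __len__(self):
        return max(0, super().__len__() - self.skip_batches)
```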

View File

@@ -3,11 +3,10 @@ import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import math
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
@@ -15,23 +14,14 @@ from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
def Logger(content):
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
with torch.no_grad():
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
@@ -45,25 +35,23 @@ def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
return (temperature ** 2) * kl
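The hunk elides the middle of `distillation_loss`; here is a sketch of the whole function under the standard KL-distillation formulation, where the lines between the teacher softmax and the return are an assumption consistent with the shown first and last lines:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
    with torch.no_grad():
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    kl = F.kl_div(student_log_probs, teacher_probs, reduction=reduction)
    # the T^2 factor keeps gradient magnitudes comparable across temperatures
    return (temperature ** 2) * kl
```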
def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
start_time = time.time()
if teacher_model is not None:
teacher_model.eval()
teacher_model.requires_grad_(False)
for step, (X, Y, loss_mask) in enumerate(train_loader):
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iter_per_epoch + step,
args.epochs * iter_per_epoch,
args.learning_rate)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# forward pass (student model)
with ctx:
with autocast_ctx:
res = model(X)
student_logits = res.logits
@@ -71,11 +59,11 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
if teacher_model is not None:
with torch.no_grad():
teacher_logits = teacher_model(X).logits
vocab_size_student = student_logits.size(-1) # N
vocab_size_student = student_logits.size(-1)
teacher_logits = teacher_logits[..., :vocab_size_student]
# ========== compute the losses ==========
# 1) Ground-Truth CE Loss (optional)
# 1) Ground-Truth CE Loss
loss_mask_flat = loss_mask.view(-1)
ce_loss = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
@@ -87,10 +75,9 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
if lm_config_student.use_moe:
ce_loss += res.aux_loss
# 2) Distillation Loss (optional)
# 2) Distillation Loss
if teacher_model is not None:
# distill only at valid token positions
distill_loss = distillation_loss_fn(
distill_loss = distillation_loss(
student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
temperature=temperature
@@ -110,157 +97,126 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch,
args.epochs - 1,
step,
iter_per_epoch,
loss.item(),
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
)
)
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
current_loss = loss.item() * args.accumulation_steps
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} ce:{ce_loss.item():.4f} distill:{distill_loss.item():.4f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
if wandb:
wandb.log({
"loss": loss.item(),
"loss": current_loss,
"ce_loss": ce_loss.item(),
"distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
"lr": optimizer.param_groups[-1]['lr'],
"last-time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
"lr": current_lr,
"epoch_Time": eta_min
})
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_path = '_moe' if lm_config_student.use_moe else ''
ckp = f'{args.save_dir}/full_dist_{lm_config_student.hidden_size}{moe_path}.pth'
moe_suffix = '_moe' if lm_config_student.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
state_dict = {k: v.half() for k, v in state_dict.items()} # save in half precision
torch.save(state_dict, ckp)
lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
def init_student_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('../model/')
model = MiniMindForCausalLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False)
Logger(f'Student model (LLM) total params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
model = model.to(args.device)
return model, tokenizer
def init_teacher_model(lm_config):
model = MiniMindForCausalLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False)
Logger(f'Teacher model (LLM) total params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
model = model.to(args.device)
return model
def init_distributed_mode():
if not ddp: return
global ddp_local_rank, DEVICE
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
DEVICE = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(DEVICE)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
parser.add_argument("--out_dir", type=str, default="../out")
parser.add_argument("--epochs", type=int, default=6)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--learning_rate", type=float, default=5e-6)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument("--max_seq_len", type=int, default=512)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument("--data_path", type=str, default="../dataset/sft_xxx.jsonl")
parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='full_dist', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=6, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--learning_rate", type=float, default=5e-6, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument("--max_seq_len", type=int, default=512, help="训练的最大截断长度")
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
parser.add_argument('--student_hidden_size', default=512, type=int, help="学生模型隐藏层维度")
parser.add_argument('--student_num_layers', default=8, type=int, help="学生模型隐藏层数量")
parser.add_argument('--teacher_hidden_size', default=768, type=int, help="教师模型隐藏层维度")
parser.add_argument('--teacher_num_layers', default=16, type=int, help="教师模型隐藏层数量")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--from_student_weight', default='full_sft', type=str, help="学生模型基于哪个权重")
parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="教师模型基于哪个权重")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--alpha', default=0.5, type=float, help="CE损失权重总损失=alpha*CE+(1-alpha)*KL")
parser.add_argument('--temperature', default=2.0, type=float, help="蒸馏温度")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb项目名")
args = parser.parse_args()
# define the student and teacher models
lm_config_student = MiniMindConfig(hidden_size=512, num_hidden_layers=8)
lm_config_teacher = MiniMindConfig(hidden_size=768, num_hidden_layers=16)
args.save_dir = os.path.join(args.out_dir)
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Set up directories, model config, check for checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * args.max_seq_len
lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=args.use_moe)
lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=args.use_moe)
ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
args.wandb_run_name = f"MiniMind-Dist-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# also seed CUDA
torch.cuda.manual_seed(base_seed + rank)
if args.use_wandb and (not ddp or ddp_local_rank == 0):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Set up wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
# initialize the student and teacher models
model, tokenizer = init_student_model(lm_config_student)
teacher_model = init_teacher_model(lm_config_teacher)
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define student and teacher models ==========
model, tokenizer = init_model(lm_config_student, args.from_student_weight)
Logger(f'Student model total params: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight)
teacher_model.eval()
teacher_model.requires_grad_(False)
Logger(f'Teacher model total params: {sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=(train_sampler is None),
num_workers=args.num_workers,
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Wrap model with DDP ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
train_epoch(epoch, wandb)
if epoch == start_epoch and start_step > 0: # first epoch and a checkpoint exists
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + start_step + 1, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
else: # default: start from the beginning
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
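`lm_checkpoint` also lives in `trainer_utils`. Its call sites imply a dual role: called with only a config and weight name it loads and returns the checkpoint dict consumed in step 6; called with model/optimizer/scaler it saves one. A plausible reconstruction follows, where the file layout and the run-id capture are assumptions:

```python
import os
import torch

def lm_checkpoint(lm_config, weight, model=None, optimizer=None, scaler=None,
                  epoch=0, step=0, wandb=None, save_dir='../checkpoints'):
    # Hypothetical reconstruction, not the repo's actual implementation.
    moe_suffix = '_moe' if lm_config.use_moe else ''
    path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_suffix}_resume.pth'  # assumed layout
    if model is None:  # load mode: hand back the saved state, if any
        return torch.load(path, map_location='cpu') if os.path.exists(path) else None
    os.makedirs(save_dir, exist_ok=True)  # save mode: persist a resumable snapshot
    raw = model.module if hasattr(model, 'module') else model  # unwrap DDP
    run_id = getattr(getattr(wandb, 'run', None), 'id', None)  # how the id is stored is a guess
    torch.save({'model': raw.state_dict(), 'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(), 'epoch': epoch, 'step': step,
                'wandb_id': run_id}, path)
```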

View File

@@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import math
import warnings
import torch
import torch.nn.functional as F
@@ -15,55 +14,47 @@ from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import DPODataset
from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
def Logger(content):
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def logits_to_probs(logits, labels):
def logits_to_log_probs(logits, labels):
# logits shape: (batch_size, seq_len, vocab_size)
# labels shape: (batch_size, seq_len)
# probs shape: (batch_size, seq_len)
# log_probs shape: (batch_size, seq_len)
log_probs = F.log_softmax(logits, dim=2)
probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
return probs
log_probs_per_token = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
return log_probs_per_token
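A quick shape check of the renamed helper: for every position, `gather` picks out the log-probability the model assigned to the actual label token.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 10)                   # (batch, seq, vocab)
labels = torch.randint(0, 10, (2, 3))            # (batch, seq)
per_token = logits_to_log_probs(logits, labels)  # (batch, seq)
assert per_token.shape == (2, 3)
expected = F.log_softmax(logits, dim=2).gather(2, labels.unsqueeze(2)).squeeze(-1)
assert torch.allclose(per_token, expected)
```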
def dpo_loss(ref_probs, probs, mask, beta):
# ref_probs and probs both have shape (batch_size, seq_len)
def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
# ref_log_probs and policy_log_probs both have shape (batch_size, seq_len)
# https://github.com/jingyaogong/minimind/issues/298
seq_lengths = mask.sum(dim=1, keepdim=True) # (batch_size, 1)
ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()
seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8) # guard against zero-length masks causing NaN via division by zero
ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
# split the chosen and rejected halves
batch_size = ref_probs.shape[0]
chosen_ref_probs = ref_probs[:batch_size // 2]
reject_ref_probs = ref_probs[batch_size // 2:]
chosen_probs = probs[:batch_size // 2]
reject_probs = probs[batch_size // 2:]
batch_size = ref_log_probs.shape[0]
chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
reject_ref_log_probs = ref_log_probs[batch_size // 2:]
chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
reject_policy_log_probs = policy_log_probs[batch_size // 2:]
pi_logratios = chosen_probs - reject_probs
ref_logratios = chosen_ref_probs - reject_ref_probs
pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
logits = pi_logratios - ref_logratios
loss = -F.logsigmoid(beta * logits)
return loss.mean()
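A toy sanity check of `dpo_loss` as defined above. The batch layout matters: `train_epoch` concatenates chosen before rejected, which is exactly the first-half/second-half split the function assumes.

```python
import torch

B, T = 4, 6                  # 2 chosen + 2 rejected sequences
ref = torch.zeros(B, T)      # reference model is indifferent
pol = torch.zeros(B, T)
pol[:2] += 0.5               # policy assigns higher log-probs to the chosen half
pol[2:] -= 0.5               # ...and lower ones to the rejected half
mask = torch.ones(B, T)
print(dpo_loss(ref, pol, mask, beta=0.1))  # ~0.644 < log(2): better than indifference
```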
def train_epoch(epoch, wandb):
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
start_time = time.time()
for step, batch in enumerate(train_loader):
for step, batch in enumerate(loader, start=start_step + 1):
x_chosen = batch['x_chosen'].to(args.device)
x_rejected = batch['x_rejected'].to(args.device)
y_chosen = batch['y_chosen'].to(args.device)
@@ -74,21 +65,21 @@ def train_epoch(epoch, wandb):
y = torch.cat([y_chosen, y_rejected], dim=0)
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with ctx:
with autocast_ctx:
with torch.no_grad():
ref_outputs = ref_model(x)
ref_logits = ref_outputs.logits
ref_probs = logits_to_probs(ref_logits, y)
ref_probs = ref_probs * mask
ref_log_probs = logits_to_log_probs(ref_logits, y)
outputs = model(x)
logits = outputs.logits
probs = logits_to_probs(logits, y)
probs = probs * mask
loss = dpo_loss(ref_probs, probs, mask, beta=0.1)
policy_log_probs = logits_to_log_probs(logits, y)
loss = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()
@@ -100,150 +91,116 @@ def train_epoch(epoch, wandb):
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item() * args.accumulation_steps,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
current_loss = loss.item() * args.accumulation_steps
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss.item() * args.accumulation_steps,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
state_dict = {k: v.half() for k, v in state_dict.items()} # save in half precision
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('../model/')
model = MiniMindForCausalLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False)
# initialize the reference model
ref_model = MiniMindForCausalLM(lm_config)
ref_model.load_state_dict(state_dict, strict=False)
ref_model.eval()
ref_model.requires_grad_(False)
Logger(f'Total LLM params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
model = model.to(args.device)
ref_model = ref_model.to(args.device)
return model, ref_model, tokenizer
def init_distributed_mode():
if not ddp: return
global ddp_local_rank, DEVICE
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
DEVICE = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(DEVICE)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind RLHF")
parser.add_argument("--out_dir", type=str, default="../out")
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=4)
# SFT-stage lr goes 5e-6 -> 5e-7 at length 512; for offline chosen/rejected ("probability") preference alignment use lr <= 1e-8 at length 3000, otherwise the model easily forgets and degrades
parser.add_argument("--learning_rate", type=float, default=1e-8)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-RLHF-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=1024, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl")
parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
parser.add_argument("--learning_rate", type=float, default=5e-8, help="初始学习率(建议<=5e-8避免遗忘")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
args = parser.parse_args()
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
use_moe=args.use_moe)
args.save_dir = os.path.join(args.out_dir)
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Set up directories, model config, check for checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * args.max_seq_len
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
args.wandb_run_name = f"MiniMind-Full-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# also seed CUDA
torch.cuda.manual_seed(base_seed + rank)
if args.use_wandb and (not ddp or ddp_local_rank == 0):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Set up wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
model, ref_model, tokenizer = init_model(lm_config)
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define model and reference model ==========
model, tokenizer = init_model(lm_config, args.from_weight)
Logger(f'Policy model total params: {sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
# initialize the reference model (ref_model, frozen)
ref_model, _ = init_model(lm_config, args.from_weight)
ref_model.eval()
ref_model.requires_grad_(False)
Logger(f'Reference model total params: {sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=(train_sampler is None),
num_workers=args.num_workers,
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Wrap model with DDP ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
train_epoch(epoch, wandb)
if epoch == start_epoch and start_step > 0: # first epoch and a checkpoint exists
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, lm_config, start_step, wandb, args.beta)
else: # default: start from the beginning
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)

View File

@@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import math
import warnings
import torch
import torch.distributed as dist
@@ -14,34 +13,25 @@ from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
def Logger(content):
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def train_epoch(epoch, wandb):
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(train_loader):
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with ctx:
with autocast_ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
@@ -63,141 +53,109 @@ def train_epoch(epoch, wandb):
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item() * args.accumulation_steps,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
current_loss = loss.item() * args.accumulation_steps
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss.item() * args.accumulation_steps,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
state_dict = {k: v.half() for k, v in state_dict.items()} # save in half precision
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
model.train()
def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('../model')
model = MiniMindForCausalLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/pretrain_{lm_config.hidden_size}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False)
Logger(f'Total trainable LLM params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
model = model.to(args.device)
return model, tokenizer
def init_distributed_mode():
if not ddp: return
global ddp_local_rank, DEVICE
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
DEVICE = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(DEVICE)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
parser.add_argument("--out_dir", type=str, default="../out")
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--learning_rate", type=float, default=5e-7)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='full_sft', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
parser.add_argument("--learning_rate", type=float, default=5e-7, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练为none则不基于任何权重训练")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名")
args = parser.parse_args()
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
use_moe=args.use_moe)
args.save_dir = os.path.join(args.out_dir)
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Set up directories, model config, check for checkpoint ==========
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * args.max_seq_len
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# also seed CUDA
torch.cuda.manual_seed(base_seed + rank)
if args.use_wandb and (not ddp or ddp_local_rank == 0):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Set up wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
model, tokenizer = init_model(lm_config)
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define model, data, optimizer ==========
model, tokenizer = init_model(lm_config, args.from_weight)
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=(train_sampler is None),
num_workers=args.num_workers,
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Wrap model with DDP ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
train_epoch(epoch, wandb)
if epoch == start_epoch and start_step > 0: # first epoch and a checkpoint exists
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
else: # default: start from the beginning
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), 0, wandb)

View File

@@ -3,59 +3,49 @@ import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import re
import gc
import warnings
import torch
from contextlib import nullcontext
import torch.distributed as dist
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from trainer.trainer_utils import *
def Logger(content):
if not ddp or dist.get_rank() == 0:
print(content)
warnings.filterwarnings('ignore')
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
def reasoning_model_reward(rewards):
# 1. 格式奖励(仅针对训练推理模型时使用)
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
if match_pattern or match_pattern2:
format_rewards.append(0.5)
else:
format_rewards.append(0.0)
rewards += torch.tensor(format_rewards, device=args.device)
# 2. 标记奖励(防止严格奖励稀疏,仅针对训练推理模型时使用)
def mark_num(text):
reward = 0
if text.count("<think>") == 1:
reward += 0.25
if text.count("</think>") == 1:
reward += 0.25
if text.count("<answer>") == 1:
reward += 0.25
if text.count("</answer>") == 1:
reward += 0.25
if text.count("<think>") == 1: reward += 0.25
if text.count("</think>") == 1: reward += 0.25
if text.count("<answer>") == 1: reward += 0.25
if text.count("</answer>") == 1: reward += 0.25
return reward
mark_rewards = [mark_num(response) for response in responses]
@ -63,12 +53,9 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
rewards = torch.zeros(len(responses), device=args.device)
# 3. 格式奖励
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
# 4. 使用reward model计算奖励
with torch.no_grad():
reward_model_scores = []
batch_size = len(prompts)
@ -105,8 +92,8 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
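For intuition: a response matching one of the strict patterns and containing each tag exactly once collects 0.5 + 4 * 0.25 = 1.5 rule-based reward before the reward-model score is mixed in. A standalone re-check of the tag bonus (same counting logic as mark_num above; the sample string is illustrative):

text = "<think>\nsome reasoning\n</think>\n<answer>\n42\n</answer>"
tag_bonus = sum(0.25 for tag in ("<think>", "</think>", "<answer>", "</answer>")
                if text.count(tag) == 1)
assert tag_bonus == 1.0  # plus 0.5 once the format regex also matches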
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch['prompt'] # list[str], length B
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
@ -115,7 +102,9 @@ def grpo_train_epoch(epoch, wandb):
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
with torch.no_grad():
# a DDP-wrapped model exposes generate() via .module
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
outputs = model_for_gen.generate(
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R]
@ -161,36 +150,33 @@ def grpo_train_epoch(epoch, wandb):
scheduler.step()
optimizer.zero_grad()
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item()
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
current_lr = optimizer.param_groups[0]['lr']
Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
if wandb and is_main_process():
wandb.log({
"policy_loss": policy_loss_val,
"reward": avg_reward_val,
"avg_response_len": avg_len_val,
"advantages_mean": advantages.mean().item(),
"learning_rate": current_lr
})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
@ -199,119 +185,114 @@ def grpo_train_epoch(epoch, wandb):
gc.collect()
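The advantage computation feeding the policy loss falls inside an elided hunk. In GRPO it is group-relative: the num_generations completions sampled per prompt are normalized against each other, with no learned critic. A sketch under that assumption (function name and epsilon are illustrative):

def group_relative_advantages(rewards, num_generations, eps=1e-4):
    # rewards: tensor of shape [B * num_generations]
    grouped = rewards.view(-1, num_generations)       # [B, num_gen]
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + eps)).view(-1)  # back to [B * num_gen]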
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind GRPO (Group Relative Policy Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='grpo', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument("--learning_rate", type=float, default=8e-8, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--num_generations", type=int, default=8, help="每个prompt生成的样本数")
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型1:推理模型')
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe)
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. 配wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 初始化模型和数据 ==========
tokenizer = AutoTokenizer.from_pretrained('../model/')
moe_suffix = '_moe' if lm_config.use_moe else ''
base_weight = "reason" if args.reasoning == 1 else "full_sft"
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = torch.load(ckp, map_location=args.device)
# Policy模型
model = MiniMindForCausalLM(lm_config)
model.load_state_dict(state_dict, strict=False)
model = model.to(args.device)
Logger(f'Policy模型总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
# Reference模型
ref_model = MiniMindForCausalLM(lm_config)
ref_model.load_state_dict(state_dict, strict=False)
ref_model.eval().requires_grad_(False)
ref_model = ref_model.to(args.device)
reward_name = "../../internlm2-1_8b-reward"
# Reward模型
reward_model = AutoModel.from_pretrained(
args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True
).to(args.device).eval().requires_grad_(False)
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
# 数据和优化器
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
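# Worked example for T_max (illustrative numbers): len(train_ds)=1000 and
# batch_size=2 give iters=500; with accumulation_steps=1 and epochs=1 the
# schedule spans 500 optimizer steps, decaying from 8e-8 to eta_min=8e-9.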
# ========== 6. 从ckp恢复状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step从step {start_step + 1}开始')
grpo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
else: # 默认从头开始
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
drop_last=False, shuffle=(train_sampler is None),
num_workers=args.num_workers, sampler=train_sampler)
grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)

View File

@ -6,49 +6,39 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import math
import torch
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import SFTDataset
from model.model_lora import save_lora, apply_lora
from trainer.trainer_utils import *

# almost identical to the full_sft code
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
loss = loss / args.accumulation_steps
@ -64,146 +54,122 @@ def train_epoch(epoch, wandb):
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
# LoRA只保存LoRA权重
save_lora(model, lora_save_path)
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
parser.add_argument("--out_dir", type=str, default="../out")
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=10)
parser.add_argument("--save_interval", type=int, default=1)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl")
parser.add_argument("--lora_name", type=str, default="lora_identity", help="根据任务保存成lora_(英文/医学/心理...)")
parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
parser.add_argument("--save_dir", type=str, default="../out/lora", help="模型保存目录")
parser.add_argument("--lora_name", type=str, default="lora_identity", help="LoRA权重名称(如lora_identity/lora_medical等)")
parser.add_argument("--epochs", type=int, default=50, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=1, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl", help="LoRA训练数据路径")
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练默认full_sft")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# 同时设置 CUDA 的随机种子
torch.cuda.manual_seed(base_seed + rank)
args.wandb_run_name = f"MiniMind-Lora-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
if args.use_wandb and (not ddp or ddp_local_rank == 0):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. 配wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ==========
model, tokenizer = init_model(lm_config, args.from_weight)
apply_lora(model)
# 统计参数
total_params = sum(p.numel() for p in model.parameters())
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
# 冻结非LoRA参数收集LoRA参数
lora_params = []
for name, param in model.named_parameters():
if 'lora' in name:
param.requires_grad = True
lora_params.append(param)
else:
param.requires_grad = False
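apply_lora from model.model_lora is not shown in this diff; it injects the trainable 'lora' parameters counted above. A sketch of the standard adapter idea it presumably follows, where a frozen weight gains a low-rank update y = Wx + (alpha/r) * B(A(x)) (illustrative, not the repo's implementation):

class LoRALinear(nn.Module):
    # Wraps a frozen nn.Linear with a rank-r trainable bypass.
    def __init__(self, base, r=8, alpha=16):
        super().__init__()
        self.base = base                       # pretrained weight, frozen
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)     # adapter starts as a no-op
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.lora_B(self.lora_A(x)) * self.scaling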
# ========== 6. 定义数据和优化器 ==========
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
# ========== 7. 从ckp恢复状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'], strict=False)
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 8. DDP包模型 ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 9. 开始训练 ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step从step {start_step + 1}开始')
train_epoch(epoch, loader, len(loader) + start_step + 1, lora_params, start_step, wandb)
else: # 默认从头开始
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)

View File

@ -6,27 +6,43 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import re
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import *
# 自定义的Critic模型继承自MiniMindLM
class CriticModel(MiniMindForCausalLM):
def __init__(self, params):
super().__init__(params)
# 替换lm_head为输出单一价值的线性层
self.value_head = nn.Linear(params.hidden_size, 1)
def forward(self, input_ids=None, attention_mask=None, **kwargs):
# 使用基础模型获取隐藏状态
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
hidden_states = self.model.norm(outputs[0])
# 使用value_head获取价值估计
values = self.value_head(hidden_states).squeeze(-1)
return values
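A quick shape check of the critic head (editor's sketch; 6400 is assumed here as MiniMind's default vocab size):

cfg = MiniMindConfig(hidden_size=512, num_hidden_layers=8)
critic = CriticModel(cfg).eval()
ids = torch.randint(0, 6400, (2, 16))    # [B, T] token ids
values = critic(input_ids=ids)           # [B, T]: one value estimate per position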
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
def reasoning_model_reward(rewards):
# 1. 格式奖励(仅针对训练推理模型时使用)
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
@ -66,7 +82,7 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
# 格式奖励
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
# 使用reward model计算整个response的奖励
with torch.no_grad():
@ -91,7 +107,6 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
answer_score = max(min(answer_score, scale), -scale)
score = score * 0.4 + answer_score * 0.6
reward_model_scores.append(score)
@ -101,19 +116,20 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step=0, wandb=None):
actor_model.train()
critic_model.train()
for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch["prompt"] # list[str], length B
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
max_length=args.max_seq_len).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
prompt_lengths = enc.attention_mask.sum(dim=1) # [B]
with torch.no_grad():
# a DDP-wrapped model exposes generate() via .module
model_for_gen = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
gen_out = model_for_gen.generate(
input_ids=enc.input_ids, attention_mask=enc.attention_mask,
max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id) # [B, P+R]
@ -164,7 +180,7 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
if is_main_process():
response_ids = gen_out[:, enc.input_ids.shape[1]:]
is_eos = (response_ids == tokenizer.eos_token_id)
eos_indices = torch.argmax(is_eos.int(), dim=1)
@ -181,8 +197,8 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
actor_lr = actor_optimizer.param_groups[0]['lr']
critic_lr = critic_optimizer.param_groups[0]['lr']
if wandb is not None:
wandb.log({
"actor_loss": actor_loss_val,
"critic_loss": critic_loss_val,
"reward": reward_val,
@ -192,183 +208,158 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
"actor_lr": actor_lr,
})
Logger(f"Epoch: {epoch}, Step: {step + 1}/{len(train_loader)}, "
Logger(f"Epoch: {epoch+1}, Step: {step}/{iters}, "
f"Actor Loss: {actor_loss_val:.6f}, Critic Loss: {critic_loss_val:.6f}, "
f"Reward: {reward_val:.6f}, KL: {kl_val:.6f}, KL_ref: {kl_ref_val:.6f}, "
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.2e}, Critic LR: {critic_lr:.2e}")
if (step + 1) % args.update_old_actor_freq == 0:
state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()})
old_actor_model.to(args.device)
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
actor_model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
torch.save({k: v.half() for k, v in actor_state.items()}, ckp)
# 使用 lm_checkpoint 保存完整状态(包括 critic
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints',
scheduler=actor_scheduler, critic_model=critic_model,
critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
actor_model.train()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='ppo_actor', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument("--learning_rate", type=float, default=8e-8, help="Actor学习率")
parser.add_argument("--critic_learning_rate", type=float, default=8e-8, help="Critic学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--clip_epsilon", type=float, default=0.1, help="PPO裁剪参数")
parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数")
parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型1:推理模型')
parser.add_argument("--update_old_actor_freq", type=int, default=4, help="更新old_actor_model的频率")
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. 配wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. 初始化模型和数据 ==========
tokenizer = AutoTokenizer.from_pretrained('../model/', padding_side='left')
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
moe_suffix = '_moe' if lm_config.use_moe else ''
base_weight = "reason" if args.reasoning == 1 else "full_sft"
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = torch.load(ckp, map_location=args.device)
# Actor模型
actor_model = MiniMindForCausalLM(lm_config)
actor_model.load_state_dict(state_dict, strict=False)
actor_model = actor_model.to(args.device)
Logger(f'Actor模型总参数量{sum(p.numel() for p in actor_model.parameters() if p.requires_grad) / 1e6:.3f} M')
# Old Actor模型
old_actor_model = MiniMindForCausalLM(lm_config)
old_actor_model.load_state_dict(state_dict, strict=False)
old_actor_model = old_actor_model.eval().requires_grad_(False).to(args.device)
# Reference模型
ref_model = MiniMindForCausalLM(lm_config)
ref_model.load_state_dict(state_dict, strict=False)
ref_model = ref_model.eval().requires_grad_(False).to(args.device)
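# The frozen reference policy anchors the KL penalty (weighted by --kl_coef),
# keeping the actor close to its reason/full_sft initialization.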
# Critic模型
critic_model = CriticModel(lm_config)
critic_model.load_state_dict(state_dict, strict=False)
critic_model = critic_model.to(args.device)
reward_name = "../../internlm2-1_8b-reward"
Logger(f'Critic模型总参数量{sum(p.numel() for p in critic_model.parameters() if p.requires_grad) / 1e6:.3f} M')
# Reward模型
reward_model = AutoModel.from_pretrained(
reward_name, device_map="cuda", torch_dtype=torch.float32, trust_remote_code=True
args.reward_model_path, device_map="cuda", torch_dtype=torch.float32, trust_remote_code=True
).to(args.device).eval().requires_grad_(False)
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
# 数据和优化器
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len))
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate)
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
# ========== 6. 从ckp恢复状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
actor_model.load_state_dict(ckp_data['model'])
critic_model.load_state_dict(ckp_data['critic_model'])
actor_optimizer.load_state_dict(ckp_data['optimizer'])
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
actor_scheduler.load_state_dict(ckp_data['scheduler'])
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. DDP包模型 ==========
if dist.is_initialized():
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
# old_actor_model 不需要DDP包装因为它只在主进程上用于计算并且不进行梯度更新
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
old_actor_model.to(args.device)
# ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step从step {start_step + 1}开始')
ppo_train_epoch(epoch, loader, len(loader) + start_step + 1, old_actor_model, ref_model,
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step, wandb)
else: # 默认从头开始
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None),
sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model,
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb)
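The clipped surrogate objective that consumes clip_epsilon sits in an elided hunk above; the standard form it implements looks like this (generic sketch, not the repo's exact code):

def ppo_policy_loss(logps, old_logps, advantages, clip_epsilon=0.1):
    ratio = torch.exp(logps - old_logps)   # new-to-old policy probability ratio
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()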

View File

@ -6,48 +6,38 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import math
import torch
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import PretrainDataset
from trainer.trainer_utils import *
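get_lr is provided by trainer_utils; assuming it keeps the cosine schedule lr/10 + 0.5*lr*(1 + math.cos(math.pi*step/total_steps)) that these trainers previously defined inline, its endpoints work out as follows (editor's arithmetic, illustrative values):

# step = 0           -> lr/10 + lr = 1.1 * lr  (brief peak)
# step = total_steps -> lr/10 + 0  = 0.1 * lr  (floor)
# e.g. get_lr(0, 1000, 5e-4) == 5.5e-4 and get_lr(1000, 1000, 5e-4) == 5e-5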
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with autocast_ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
loss = loss / args.accumulation_steps
@ -63,139 +53,108 @@ def train_epoch(epoch, wandb):
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
torch.save(state_dict, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
# torchrun --nproc_per_node 2 1-pretrain.py
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
parser.add_argument("--out_dir", type=str, default="../out")
# 若要以最快速度实现zero则epochs设置为1轮否则应当利用有限的数据训练2~6个epochs。
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数建议1轮zero或2-6轮充分训练")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练为none则从头开始")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
args = parser.parse_args()
# ========== 1. 初始化环境和随机种子 ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. 设置混合精度 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# 同时设置 CUDA 的随机种子
torch.cuda.manual_seed(base_seed + rank)
if args.use_wandb and (not ddp or ddp_local_rank == 0):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Define model, data, optimizer ==========
model, tokenizer = init_model(lm_config, args.from_weight)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Wrap model with DDP ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
if epoch == start_epoch and start_step > 0:  # first epoch with an existing checkpoint
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
else:  # default: start from the beginning
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
train_epoch(epoch, loader, len(loader), 0, wandb)
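# Example (assumed values, mirroring sections 6-8 above): suppose a run with the
# default save_weight='pretrain' and hidden_size=512 is interrupted at epoch 0,
# step 250. Relaunching with --from_resume 1 makes lm_checkpoint() return the
# state stored in '../checkpoints/pretrain_512_resume.pth', so start_step=250,
# SkipBatchSampler fast-forwards past the first 251 batches, and training
# resumes at step 251 instead of replaying data the model has already seen.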


@@ -3,31 +3,35 @@ import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import re
import gc
import warnings
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from collections import defaultdict
from trainer.trainer_utils import *
warnings.filterwarnings('ignore')
class AutoAdaptiveValueTracker:
"""SPO自适应价值追踪器"""
def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
self.rho_mode = rho_mode
self.rho_const = rho_const
self.D_half = D_half
self.clip_lower = clip_lower
self.clip_upper = clip_upper
# Stable initialization following N_init = 1/(1-clip_lower)
N_init = 1.0 / (1.0 - self.clip_lower)
self.alpha = 0.5 * N_init
self.beta = 0.5 * N_init
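# Worked arithmetic for the defaults (a sketch; the tracker's update rule is
# outside this hunk): clip_lower=0.5 gives N_init = 1/(1-0.5) = 2.0, so
# alpha = beta = 1.0 and, assuming the usual Beta-style mean alpha/(alpha+beta),
# the initial value estimate starts at 0.5 with an effective count of 2.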
@@ -62,43 +66,28 @@
return rho
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
"""整合所有奖励函数计算总奖励"""
def reasoning_model_reward(rewards):
# 1. Format reward (only applied when training a reasoning model)
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
format_rewards = []
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
if match_pattern or match_pattern2:
format_rewards.append(0.5)
else:
format_rewards.append(0.0)
rewards += torch.tensor(format_rewards, device=args.device)
# 2. Tag-count reward (keeps the strict reward from being too sparse; reasoning models only)
def mark_num(text):
reward = 0
if text.count("<think>") == 1:
reward += 0.25
if text.count("</think>") == 1:
reward += 0.25
if text.count("<answer>") == 1:
reward += 0.25
if text.count("</answer>") == 1:
reward += 0.25
if text.count("<think>") == 1: reward += 0.25
if text.count("</think>") == 1: reward += 0.25
if text.count("<answer>") == 1: reward += 0.25
if text.count("</answer>") == 1: reward += 0.25
return reward
mark_rewards = [mark_num(response) for response in responses]
@@ -106,12 +95,9 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
rewards = torch.zeros(len(responses), device=args.device)
# 3. Format reward
if args.reasoning == 1:
rewards = reasoning_model_reward(rewards)
# 4. Compute rewards with the reward model
with torch.no_grad():
reward_model_scores = []
scale = 3.0
@@ -142,8 +128,8 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
return rewards
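# Worked example of the rule-based rewards above: a response of the exact form
#   "<think>\n...\n</think>\n<answer>\n...\n</answer>"
# matches the strict format regex (+0.5) and contains each of the four tags
# exactly once (+0.25 each), so it earns 1.5 before the scaled reward-model
# score is added; an untagged response earns 0.0 from these two rules.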
def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, value_tracker, start_step=0, wandb=None):
for step, batch in enumerate(loader, start=start_step + 1):
prompts = batch['prompt'] # list[str], length B
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
@@ -152,7 +138,9 @@ def spo_train_epoch(epoch, wandb, value_tracker):
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
with torch.no_grad():
# a DDP-wrapped model exposes generate via .module
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
outputs = model_for_gen.generate(
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
num_return_sequences=1, pad_token_id=tokenizer.pad_token_id) # [B, P+R]
@@ -205,42 +193,38 @@ def spo_train_epoch(epoch, wandb, value_tracker):
scheduler.step()
optimizer.zero_grad()
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item()
avg_reward_val = rewards.mean().item()
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
# average kl over valid tokens for logging
kl_val = ((per_token_kl * completion_mask).sum() / (completion_mask.sum() + 1e-8)).item()
avg_baseline_val = baselines.mean().item()
current_lr = optimizer.param_groups[0]['lr']
Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, '
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
if wandb and is_main_process():
wandb.log({
"policy_loss": policy_loss_val,
"reward": avg_reward_val,
"kl": kl_val,
"rho": float(rho),
"baseline": avg_baseline_val,
# "avg_response_len": avg_len_val,
"advantages_mean": advantages.mean().item(),
"learning_rate": current_lr
})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
model.train()
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
@@ -249,120 +233,116 @@ def spo_train_epoch(epoch, wandb, value_tracker):
gc.collect()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind SPO (Self-Play Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='spo', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument("--learning_rate", type=float, default=1e-7, help="初始学习率")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=4, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型1:推理模型')
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训0否1是")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO", help="wandb项目名")
args = parser.parse_args()
# ========== 1. Initialize environment and random seed ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
# ========== 2. Set up directories, model config, check for checkpoints ==========
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe)
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# ========== 3. Set up mixed precision ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# ========== 4. Configure wandb ==========
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-SPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# ========== 5. Initialize models (Policy, Ref, Reward), value tracker and data ==========
tokenizer = AutoTokenizer.from_pretrained('../model/')
moe_suffix = '_moe' if lm_config.use_moe else ''
base_weight = "reason" if args.reasoning == 1 else "full_sft"
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
state_dict = torch.load(ckp, map_location=args.device)
# Policy model
model = MiniMindForCausalLM(lm_config)
model.load_state_dict(state_dict, strict=False)
model = model.to(args.device)
Logger(f'Policy model trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
# Reference model
ref_model = MiniMindForCausalLM(lm_config)
ref_model.load_state_dict(state_dict, strict=False)
ref_model.eval().requires_grad_(False)
ref_model = ref_model.to(args.device)
reward_name = "../../internlm2-1_8b-reward"
# Reward model
reward_model = AutoModel.from_pretrained(
args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True
).to(args.device).eval().requires_grad_(False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--out_dir", type=str, default="../out")
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--learning_rate", type=float, default=1e-7)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=4)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--log_interval", type=int, default=1)
parser.add_argument("--save_interval", type=int, default=10)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument('--max_seq_len', default=66, type=int)
parser.add_argument("--max_gen_len", type=int, default=1536)
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
parser.add_argument("--beta", type=float, default=0.02)
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型1:推理模型')
args = parser.parse_args()
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
max_seq_len=args.max_seq_len + args.max_gen_len,
use_moe=args.use_moe)
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
ctx = nullcontext() if "cuda" not in args.device else torch.amp.autocast('cuda')
ddp = int(os.environ.get("RANK", -1)) != -1
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# 同时设置 CUDA 的随机种子
torch.cuda.manual_seed(base_seed + rank)
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import swanlab as wandb
wandb.init(project=args.wandb_project)
else:
wandb = None
model, ref_model, tokenizer, reward_model, reward_tokenizer = init_model(lm_config=lm_config)
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
# Value Tracker
value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
iters = len(loader_for_count)
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
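# Example: a loader yielding iters=1000 batches with the default
# accumulation_steps=4 and epochs=1 gives total_optimizer_steps = (1000 // 4) * 1
# = 250, so the cosine schedule anneals the learning rate from 1e-7 down to
# eta_min=1e-8 over 250 optimizer steps.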
# ========== 6. Restore state from checkpoint ==========
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scheduler.load_state_dict(ckp_data['scheduler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# ========== 7. Wrap model with DDP ==========
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# ========== 8. Start training ==========
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
if epoch == start_epoch and start_step > 0:  # first epoch with an existing checkpoint
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
spo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, value_tracker, start_step, wandb)
else:  # default: start from the beginning
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
drop_last=False, shuffle=(train_sampler is None),
num_workers=args.num_workers, sampler=train_sampler)
spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)

trainer/trainer_utils.py (new file, 139 lines)

@@ -0,0 +1,139 @@
"""
Collection of training utility functions
"""
import os
import random
import math
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler
def is_main_process():
return not dist.is_initialized() or dist.get_rank() == 0
def Logger(content):
if is_main_process():
print(content)
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
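# Shape of this schedule (worked values): the lr follows a cosine from
# lr/10 + lr at step 0 down to lr/10 at the final step, e.g.
#   get_lr(0,    1000, 5e-4) == 5.5e-4
#   get_lr(500,  1000, 5e-4) == 3.0e-4
#   get_lr(1000, 1000, 5e-4) == 5.0e-5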
def init_distributed_mode():
if int(os.environ.get("RANK", -1)) == -1:
return 0  # non-DDP mode
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
return local_rank
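# Under a launcher such as `torchrun --nproc_per_node=N <trainer>.py`, RANK and
# LOCAL_RANK are set in the environment, so this initializes the NCCL process
# group and pins each process to its local GPU; in a plain `python <trainer>.py`
# run it returns 0 and torch.distributed stays uninitialized.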
def setup_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
os.makedirs(save_dir, exist_ok=True)
moe_path = '_moe' if lm_config.use_moe else ''
ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'
if model is not None:
from torch.nn.parallel import DistributedDataParallel
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
ckp_tmp = ckp_path + '.tmp'
torch.save({k: v.half() for k, v in state_dict.items()}, ckp_tmp)
os.replace(ckp_tmp, ckp_path)
wandb_id = None
if wandb:
if hasattr(wandb, 'get_run'):
run = wandb.get_run()
wandb_id = getattr(run, 'id', None) if run else None
else:
wandb_id = getattr(wandb, 'id', None)
resume_data = {
'model': state_dict,
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'step': step,
'world_size': dist.get_world_size() if dist.is_initialized() else 1,
'wandb_id': wandb_id
}
for key, value in kwargs.items():
if value is not None:
if hasattr(value, 'state_dict'):
if isinstance(value, DistributedDataParallel):
resume_data[key] = value.module.state_dict()
else:
resume_data[key] = value.state_dict()
else:
resume_data[key] = value
resume_tmp = resume_path + '.tmp'
torch.save(resume_data, resume_tmp)
os.replace(resume_tmp, resume_path)
else:  # load mode
if os.path.exists(resume_path):
ckp_data = torch.load(resume_path, map_location='cpu')
saved_ws = ckp_data.get('world_size', 1)
current_ws = dist.get_world_size() if dist.is_initialized() else 1
if saved_ws != current_ws:
ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
Logger(f'GPU count changed ({saved_ws} → {current_ws}); step automatically converted to {ckp_data["step"]}')
return ckp_data
return None
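# Usage sketch (call shapes taken from the trainers above):
#   save: lm_checkpoint(lm_config, weight='pretrain', model=model,
#                       optimizer=optimizer, epoch=epoch, step=step, wandb=wandb,
#                       save_dir='../checkpoints', scaler=scaler)
#   load: ckp_data = lm_checkpoint(lm_config, weight='pretrain',
#                                  save_dir='../checkpoints')  # None if no *_resume.pth
# The world-size rescale above keeps progress consistent: a step counter of 1000
# saved on 4 GPUs resumes as 1000 * 4 // 2 = 2000 when relaunched on 2 GPUs.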
def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
from transformers import AutoTokenizer
from model.model_minimind import MiniMindForCausalLM
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = MiniMindForCausalLM(lm_config)
if from_weight != 'none':
moe_suffix = '_moe' if lm_config.use_moe else ''
weight_path = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
weights = torch.load(weight_path, map_location=device)
model.load_state_dict(weights, strict=False)
Logger(f'Trainable parameters of loaded model: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
return model.to(device), tokenizer
class SkipBatchSampler(Sampler):
def __init__(self, sampler, batch_size, skip_batches=0):
self.sampler = sampler
self.batch_size = batch_size
self.skip_batches = skip_batches
def __iter__(self):
batch = []
skipped = 0
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
if skipped < self.skip_batches:
skipped += 1
batch = []
continue
yield batch
batch = []
if len(batch) > 0 and skipped >= self.skip_batches:
yield batch
def __len__(self):
total_batches = (len(self.sampler) + self.batch_size - 1) // self.batch_size
return max(0, total_batches - self.skip_batches)
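# Usage sketch (mirrors the resume branch in the trainers above):
#   sampler = train_sampler or range(len(train_ds))
#   batch_sampler = SkipBatchSampler(sampler, batch_size=32, skip_batches=101)
#   loader = DataLoader(train_ds, batch_sampler=batch_sampler)
# With 10,000 samples and batch_size=32 an epoch has ceil(10000/32) = 313
# batches; skipping the first 101 leaves len(loader) == 212, so the resumed
# epoch iterates only the unseen tail of the data.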