mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-05-01 11:48:14 +08:00
[feat] pause-training
This commit is contained in:
parent
6efba3249a
commit
e8484874f5
@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -14,23 +13,14 @@ from contextlib import nullcontext
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import SFTDataset
|
||||
from trainer.trainer_utils import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def train_epoch(epoch, wandb):
|
||||
def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
|
||||
# 思考标签占位符
|
||||
start_of_think_ids = tokenizer('<think>').input_ids
|
||||
end_of_think_ids = tokenizer('</think>').input_ids
|
||||
@ -38,28 +28,30 @@ def train_epoch(epoch, wandb):
|
||||
end_of_answer_ids = tokenizer('</answer>').input_ids
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
start_time = time.time()
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
|
||||
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
with autocast_ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
).view(Y.size())
|
||||
|
||||
# 特殊标签位置增加权重(推理蒸馏特有)
|
||||
sp_ids = torch.isin(Y.view(-1),
|
||||
torch.tensor(start_of_think_ids + end_of_think_ids
|
||||
+ start_of_answer_ids + end_of_answer_ids
|
||||
).to(args.device))
|
||||
# 在 sp_ids 对应的位置增加额外的惩罚
|
||||
loss_mask = loss_mask.view(-1)
|
||||
loss_mask_sum = loss_mask.sum()
|
||||
loss_mask[sp_ids] = 10
|
||||
loss_mask[sp_ids] = 10 # 对思考标签增加10倍权重
|
||||
loss_mask = loss_mask.view(Y.size())
|
||||
loss = (loss * loss_mask).sum() / loss_mask_sum
|
||||
loss += res.aux_loss
|
||||
@ -70,148 +62,112 @@ def train_epoch(epoch, wandb):
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item() * args.accumulation_steps,
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
|
||||
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
|
||||
|
||||
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": loss.item() * args.accumulation_steps,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
|
||||
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model')
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
model = model.to(args.device)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Distill Reasoning")
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
parser.add_argument("--epochs", type=int, default=1)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-6)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=1)
|
||||
parser.add_argument("--save_interval", type=int, default=50)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl")
|
||||
|
||||
parser = argparse.ArgumentParser(description="MiniMind Reasoning Distillation")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='reason', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=8, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-6, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度")
|
||||
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl", help="推理蒸馏数据路径")
|
||||
parser.add_argument('--from_weight', default='dpo', type=str, help="基于哪个权重训练,默认dpo")
|
||||
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Reasoning", help="wandb项目名")
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * args.max_seq_len
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Distill-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# 同时设置 CUDA 的随机种子
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, tokenizer = init_model(lm_config)
|
||||
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
train_epoch(epoch, wandb)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, tokenizer, lm_config, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), tokenizer, lm_config, 0, wandb)
|
||||
|
||||
@ -3,11 +3,10 @@ import sys
|
||||
|
||||
__package__ = "trainer"
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
@ -15,23 +14,14 @@ from contextlib import nullcontext
|
||||
from torch import optim
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import SFTDataset
|
||||
from trainer.trainer_utils import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
|
||||
def distillation_loss(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
|
||||
with torch.no_grad():
|
||||
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
|
||||
|
||||
@ -45,25 +35,23 @@ def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduct
|
||||
return (temperature ** 2) * kl
|
||||
|
||||
|
||||
def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
|
||||
def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_step=0, wandb=None, alpha=0.0, temperature=1.0):
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
if teacher_model is not None:
|
||||
teacher_model.eval()
|
||||
teacher_model.requires_grad_(False)
|
||||
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
lr = get_lr(epoch * iter_per_epoch + step,
|
||||
args.epochs * iter_per_epoch,
|
||||
args.learning_rate)
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
# 前向传播(学生模型)
|
||||
with ctx:
|
||||
with autocast_ctx:
|
||||
res = model(X)
|
||||
student_logits = res.logits
|
||||
|
||||
@ -71,11 +59,11 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
|
||||
if teacher_model is not None:
|
||||
with torch.no_grad():
|
||||
teacher_logits = teacher_model(X).logits
|
||||
vocab_size_student = student_logits.size(-1) # N
|
||||
vocab_size_student = student_logits.size(-1)
|
||||
teacher_logits = teacher_logits[..., :vocab_size_student]
|
||||
|
||||
# ========== 计算损失 ==========
|
||||
# 1) Ground-Truth CE Loss(可选)
|
||||
# 1) Ground-Truth CE Loss
|
||||
loss_mask_flat = loss_mask.view(-1)
|
||||
ce_loss = F.cross_entropy(
|
||||
student_logits.view(-1, student_logits.size(-1)),
|
||||
@ -87,10 +75,9 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
|
||||
if lm_config_student.use_moe:
|
||||
ce_loss += res.aux_loss
|
||||
|
||||
# 2) Distillation Loss(可选)
|
||||
# 2) Distillation Loss
|
||||
if teacher_model is not None:
|
||||
# 只在有效token位置做蒸馏
|
||||
distill_loss = distillation_loss_fn(
|
||||
distill_loss = distillation_loss(
|
||||
student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
|
||||
teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
|
||||
temperature=temperature
|
||||
@ -110,157 +97,126 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch,
|
||||
args.epochs - 1,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item(),
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
|
||||
)
|
||||
)
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
|
||||
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} ce:{ce_loss.item():.4f} distill:{distill_loss.item():.4f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
|
||||
|
||||
if wandb:
|
||||
wandb.log({
|
||||
"loss": loss.item(),
|
||||
"loss": current_loss,
|
||||
"ce_loss": ce_loss.item(),
|
||||
"distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"last-time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
|
||||
"lr": current_lr,
|
||||
"epoch_Time": eta_min
|
||||
})
|
||||
|
||||
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config_student.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_dist_{lm_config_student.hidden_size}{moe_path}.pth'
|
||||
moe_suffix = '_moe' if lm_config_student.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config_student, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
|
||||
|
||||
def init_student_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
Logger(f'学生模型(LLM)总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
model = model.to(args.device)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def init_teacher_model(lm_config):
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
Logger(f'教师模型(LLM)总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
model = model.to(args.device)
|
||||
return model
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
parser.add_argument("--epochs", type=int, default=6)
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-6)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=100)
|
||||
parser.add_argument("--max_seq_len", type=int, default=512)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/sft_xxx.jsonl")
|
||||
|
||||
parser = argparse.ArgumentParser(description="MiniMind Knowledge Distillation")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='full_dist', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=6, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-6, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
|
||||
parser.add_argument("--max_seq_len", type=int, default=512, help="训练的最大截断长度")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
|
||||
parser.add_argument('--student_hidden_size', default=512, type=int, help="学生模型隐藏层维度")
|
||||
parser.add_argument('--student_num_layers', default=8, type=int, help="学生模型隐藏层数量")
|
||||
parser.add_argument('--teacher_hidden_size', default=768, type=int, help="教师模型隐藏层维度")
|
||||
parser.add_argument('--teacher_num_layers', default=16, type=int, help="教师模型隐藏层数量")
|
||||
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
|
||||
parser.add_argument('--from_student_weight', default='full_sft', type=str, help="学生模型基于哪个权重")
|
||||
parser.add_argument('--from_teacher_weight', default='full_sft', type=str, help="教师模型基于哪个权重")
|
||||
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
|
||||
parser.add_argument('--alpha', default=0.5, type=float, help="CE损失权重,总损失=alpha*CE+(1-alpha)*KL")
|
||||
parser.add_argument('--temperature', default=2.0, type=float, help="蒸馏温度")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Distillation", help="wandb项目名")
|
||||
args = parser.parse_args()
|
||||
# 定义学生模型和教师模型
|
||||
lm_config_student = MiniMindConfig(hidden_size=512, num_hidden_layers=8)
|
||||
lm_config_teacher = MiniMindConfig(hidden_size=768, num_hidden_layers=16)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * args.max_seq_len
|
||||
lm_config_student = MiniMindConfig(hidden_size=args.student_hidden_size, num_hidden_layers=args.student_num_layers, use_moe=args.use_moe)
|
||||
lm_config_teacher = MiniMindConfig(hidden_size=args.teacher_hidden_size, num_hidden_layers=args.teacher_num_layers, use_moe=args.use_moe)
|
||||
ckp_data = lm_checkpoint(lm_config_student, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Dist-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# 同时设置 CUDA 的随机种子
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
# 初始化学生模型和教师模型
|
||||
model, tokenizer = init_student_model(lm_config_student)
|
||||
teacher_model = init_teacher_model(lm_config_teacher)
|
||||
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-Distill-S{args.student_hidden_size}T{args.teacher_hidden_size}-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义学生和教师模型 ==========
|
||||
model, tokenizer = init_model(lm_config_student, args.from_student_weight)
|
||||
Logger(f'学生模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||
teacher_model, _ = init_model(lm_config_teacher, args.from_teacher_weight)
|
||||
teacher_model.eval()
|
||||
teacher_model.requires_grad_(False)
|
||||
Logger(f'教师模型总参数量:{sum(p.numel() for p in teacher_model.parameters()) / 1e6:.3f} M')
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
train_epoch(epoch, wandb)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, teacher_model, lm_config_student, start_step, wandb, args.alpha, args.temperature)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), teacher_model, lm_config_student, 0, wandb, args.alpha, args.temperature)
|
||||
|
||||
@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -15,55 +14,47 @@ from contextlib import nullcontext
|
||||
from torch import optim
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import DPODataset
|
||||
from trainer.trainer_utils import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def logits_to_probs(logits, labels):
|
||||
def logits_to_log_probs(logits, labels):
|
||||
# logits shape: (batch_size, seq_len, vocab_size)
|
||||
# labels shape: (batch_size, seq_len)
|
||||
# probs shape: (batch_size, seq_len)
|
||||
# log_probs shape: (batch_size, seq_len)
|
||||
log_probs = F.log_softmax(logits, dim=2)
|
||||
probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
|
||||
return probs
|
||||
log_probs_per_token = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
|
||||
return log_probs_per_token
|
||||
|
||||
|
||||
def dpo_loss(ref_probs, probs, mask, beta):
|
||||
# ref_probs 和 probs 都是 shape: (batch_size, seq_len)
|
||||
def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
|
||||
# ref_log_probs 和 policy_log_probs 都是 shape: (batch_size, seq_len)
|
||||
# https://github.com/jingyaogong/minimind/issues/298
|
||||
seq_lengths = mask.sum(dim=1, keepdim=True) # (batch_size, 1)
|
||||
ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8) # 防止零长度mask导致除零NaN
|
||||
ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
|
||||
# 将 chosen 和 rejected 数据分开
|
||||
batch_size = ref_probs.shape[0]
|
||||
chosen_ref_probs = ref_probs[:batch_size // 2]
|
||||
reject_ref_probs = ref_probs[batch_size // 2:]
|
||||
chosen_probs = probs[:batch_size // 2]
|
||||
reject_probs = probs[batch_size // 2:]
|
||||
batch_size = ref_log_probs.shape[0]
|
||||
chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
|
||||
reject_ref_log_probs = ref_log_probs[batch_size // 2:]
|
||||
chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
|
||||
reject_policy_log_probs = policy_log_probs[batch_size // 2:]
|
||||
|
||||
pi_logratios = chosen_probs - reject_probs
|
||||
ref_logratios = chosen_ref_probs - reject_ref_probs
|
||||
pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
|
||||
ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
|
||||
logits = pi_logratios - ref_logratios
|
||||
loss = -F.logsigmoid(beta * logits)
|
||||
return loss.mean()
|
||||
|
||||
|
||||
def train_epoch(epoch, wandb):
|
||||
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
|
||||
start_time = time.time()
|
||||
for step, batch in enumerate(train_loader):
|
||||
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
x_chosen = batch['x_chosen'].to(args.device)
|
||||
x_rejected = batch['x_rejected'].to(args.device)
|
||||
y_chosen = batch['y_chosen'].to(args.device)
|
||||
@ -74,21 +65,21 @@ def train_epoch(epoch, wandb):
|
||||
y = torch.cat([y_chosen, y_rejected], dim=0)
|
||||
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
|
||||
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
with autocast_ctx:
|
||||
with torch.no_grad():
|
||||
ref_outputs = ref_model(x)
|
||||
ref_logits = ref_outputs.logits
|
||||
ref_probs = logits_to_probs(ref_logits, y)
|
||||
ref_probs = ref_probs * mask
|
||||
ref_log_probs = logits_to_log_probs(ref_logits, y)
|
||||
|
||||
outputs = model(x)
|
||||
logits = outputs.logits
|
||||
probs = logits_to_probs(logits, y)
|
||||
probs = probs * mask
|
||||
loss = dpo_loss(ref_probs, probs, mask, beta=0.1)
|
||||
policy_log_probs = logits_to_log_probs(logits, y)
|
||||
|
||||
loss = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
|
||||
loss = loss / args.accumulation_steps
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
@ -100,150 +91,116 @@ def train_epoch(epoch, wandb):
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item() * args.accumulation_steps,
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
|
||||
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
|
||||
|
||||
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": loss.item() * args.accumulation_steps,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
|
||||
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
# 初始化参考模型
|
||||
ref_model = MiniMindForCausalLM(lm_config)
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
ref_model.eval()
|
||||
ref_model.requires_grad_(False)
|
||||
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
model = model.to(args.device)
|
||||
ref_model = ref_model.to(args.device)
|
||||
|
||||
return model, ref_model, tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind RLHF")
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
parser.add_argument("--epochs", type=int, default=2)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
# sft阶段学习率为 「5e-6」->「5e-7」长度512,建议离线正负样本「概率」偏好对齐阶段lr <=「1e-8」长度3000,否则很容易遗忘训坏
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-8)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-RLHF-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=100)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl")
|
||||
|
||||
parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-8, help="初始学习率(建议<=5e-8避免遗忘)")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度")
|
||||
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
|
||||
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
|
||||
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
|
||||
parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * args.max_seq_len
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Full-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# 同时设置 CUDA 的随机种子
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, ref_model, tokenizer = init_model(lm_config)
|
||||
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型和参考模型 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
|
||||
# 初始化参考模型(ref_model冻结)
|
||||
ref_model, _ = init_model(lm_config, args.from_weight)
|
||||
ref_model.eval()
|
||||
ref_model.requires_grad_(False)
|
||||
Logger(f'参考模型总参数量:{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
|
||||
|
||||
train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
train_epoch(epoch, wandb)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, lm_config, start_step, wandb, args.beta)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
|
||||
|
||||
@ -6,7 +6,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -14,34 +13,25 @@ from contextlib import nullcontext
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import SFTDataset
|
||||
from trainer.trainer_utils import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def train_epoch(epoch, wandb):
|
||||
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
start_time = time.time()
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
with autocast_ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
@ -63,141 +53,109 @@ def train_epoch(epoch, wandb):
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item() * args.accumulation_steps,
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
|
||||
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
|
||||
|
||||
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": loss.item() * args.accumulation_steps,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
|
||||
model.train()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model')
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/pretrain_{lm_config.hidden_size}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
Logger(f'LLM可训练总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
model = model.to(args.device)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Full SFT")
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
parser.add_argument("--epochs", type=int, default=2)
|
||||
parser.add_argument("--batch_size", type=int, default=16)
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-7)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=100)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl")
|
||||
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='full_sft', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=2, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-7, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
|
||||
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl", help="训练数据路径")
|
||||
parser.add_argument('--from_weight', default='pretrain', type=str, help="基于哪个权重训练,为none则不基于任何权重训练")
|
||||
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="wandb项目名")
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * args.max_seq_len
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# 同时设置 CUDA 的随机种子
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, tokenizer = init_model(lm_config)
|
||||
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
train_epoch(epoch, wandb)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
@ -3,59 +3,49 @@ import sys
|
||||
|
||||
__package__ = "trainer"
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import re
|
||||
import gc
|
||||
import warnings
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
import torch.distributed as dist
|
||||
from transformers import AutoTokenizer
|
||||
from contextlib import nullcontext
|
||||
from torch import optim
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from transformers import AutoModel
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from trainer.trainer_utils import *
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
"""整合所有奖励函数计算总奖励"""
|
||||
|
||||
def reasoning_model_reward(rewards):
|
||||
# 1. 格式奖励(仅针对训练推理模型时使用)
|
||||
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
|
||||
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
|
||||
|
||||
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
|
||||
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
|
||||
|
||||
format_rewards = []
|
||||
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
|
||||
if match_pattern:
|
||||
format_rewards.append(0.5)
|
||||
elif match_pattern2:
|
||||
if match_pattern or match_pattern2:
|
||||
format_rewards.append(0.5)
|
||||
else:
|
||||
format_rewards.append(0.0)
|
||||
rewards += torch.tensor(format_rewards, device=args.device)
|
||||
|
||||
# 2. 标记奖励(防止严格奖励稀疏,仅针对训练推理模型时使用)
|
||||
def mark_num(text):
|
||||
reward = 0
|
||||
if text.count("<think>") == 1:
|
||||
reward += 0.25
|
||||
if text.count("</think>") == 1:
|
||||
reward += 0.25
|
||||
if text.count("<answer>") == 1:
|
||||
reward += 0.25
|
||||
if text.count("</answer>") == 1:
|
||||
reward += 0.25
|
||||
if text.count("<think>") == 1: reward += 0.25
|
||||
if text.count("</think>") == 1: reward += 0.25
|
||||
if text.count("<answer>") == 1: reward += 0.25
|
||||
if text.count("</answer>") == 1: reward += 0.25
|
||||
return reward
|
||||
|
||||
mark_rewards = [mark_num(response) for response in responses]
|
||||
@ -63,12 +53,9 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
return rewards
|
||||
|
||||
rewards = torch.zeros(len(responses), device=args.device)
|
||||
|
||||
# 3. 格式奖励
|
||||
if args.reasoning == 1:
|
||||
rewards = reasoning_model_reward(rewards) # 训练推理模型时使用
|
||||
rewards = reasoning_model_reward(rewards)
|
||||
|
||||
# 4. 使用reward model计算奖励
|
||||
with torch.no_grad():
|
||||
reward_model_scores = []
|
||||
batch_size = len(prompts)
|
||||
@ -105,8 +92,8 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
return rewards
|
||||
|
||||
|
||||
def grpo_train_epoch(epoch, wandb):
|
||||
for step, batch in enumerate(train_loader):
|
||||
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
prompts = batch['prompt'] # list[str], length B
|
||||
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
|
||||
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||
@ -115,7 +102,9 @@ def grpo_train_epoch(epoch, wandb):
|
||||
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = (model.module if ddp else model).generate(
|
||||
# DDP 模型需要使用 .module 访问 generate 方法
|
||||
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
outputs = model_for_gen.generate(
|
||||
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
|
||||
num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id) # [B*num_gen, P+R]
|
||||
|
||||
@ -161,36 +150,33 @@ def grpo_train_epoch(epoch, wandb):
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
if step % args.log_interval == 0 or step == iters:
|
||||
policy_loss_val = loss.item()
|
||||
avg_reward_val = rewards.mean().item()
|
||||
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
|
||||
Logger(
|
||||
f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
|
||||
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
|
||||
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
|
||||
Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
|
||||
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
|
||||
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
|
||||
|
||||
if wandb and (not ddp or dist.get_rank() == 0):
|
||||
log_dict = {
|
||||
if wandb and is_main_process():
|
||||
wandb.log({
|
||||
"policy_loss": policy_loss_val,
|
||||
"reward": avg_reward_val,
|
||||
"avg_response_len": avg_len_val,
|
||||
"advantages_mean": advantages.mean().item(),
|
||||
"learning_rate": current_lr
|
||||
}
|
||||
wandb.log(log_dict)
|
||||
})
|
||||
|
||||
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
suffix = 'grpo'
|
||||
ckp = f'{args.save_dir}/{suffix}_{lm_config.hidden_size}{moe_path}.pth'
|
||||
|
||||
state_dict = model.module.state_dict() if isinstance(model,
|
||||
torch.nn.parallel.DistributedDataParallel) else model.state_dict()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
||||
torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
|
||||
model.train()
|
||||
|
||||
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
||||
@ -199,119 +185,114 @@ def grpo_train_epoch(epoch, wandb):
|
||||
gc.collect()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
|
||||
if args.reasoning == 1:
|
||||
ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind GRPO (Group Relative Policy Optimization)")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='grpo', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=8e-8, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
|
||||
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
|
||||
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
|
||||
parser.add_argument("--num_generations", type=int, default=8, help="每个prompt生成的样本数")
|
||||
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
|
||||
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
|
||||
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||||
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb项目名")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe)
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 初始化模型和数据 ==========
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
base_weight = "reason" if args.reasoning == 1 else "full_sft"
|
||||
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
# Policy模型
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = model.to(args.device)
|
||||
Logger(f'Policy模型总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
# Reference模型
|
||||
ref_model = MiniMindForCausalLM(lm_config)
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
ref_model.eval().requires_grad_(False)
|
||||
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
model = model.to(args.device)
|
||||
ref_model = ref_model.to(args.device)
|
||||
|
||||
reward_name = "../../internlm2-1_8b-reward"
|
||||
# Reward模型
|
||||
reward_model = AutoModel.from_pretrained(
|
||||
reward_name,
|
||||
device_map="cuda",
|
||||
torch_dtype=torch.float16,
|
||||
trust_remote_code=True,
|
||||
args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True
|
||||
).to(args.device).eval().requires_grad_(False)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True)
|
||||
|
||||
return model, ref_model, tokenizer, reward_model, reward_tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
parser.add_argument("--epochs", type=int, default=1)
|
||||
parser.add_argument("--batch_size", type=int, default=2)
|
||||
parser.add_argument("--learning_rate", type=float, default=8e-8)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--log_interval", type=int, default=1)
|
||||
parser.add_argument("--save_interval", type=int, default=10)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--max_seq_len', default=66, type=int)
|
||||
parser.add_argument("--max_gen_len", type=int, default=1536)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
|
||||
parser.add_argument("--num_generations", type=int, default=8)
|
||||
parser.add_argument("--beta", type=float, default=0.02)
|
||||
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
max_seq_len=args.max_seq_len + args.max_gen_len,
|
||||
use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
ctx = nullcontext() if "cuda" not in args.device else torch.amp.autocast('cuda')
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# 同时设置 CUDA 的随机种子
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import swanlab as wandb
|
||||
|
||||
wandb.init(project=args.wandb_project)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, ref_model, tokenizer, reward_model, reward_tokenizer = init_model(lm_config=lm_config)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
|
||||
# 数据和优化器
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
|
||||
iters = len(loader_for_count)
|
||||
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
grpo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
|
||||
if ddp:
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
grpo_train_epoch(epoch, wandb)
|
||||
grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
|
||||
|
||||
@ -6,49 +6,39 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
from torch import optim, nn
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import SFTDataset
|
||||
from model.model_lora import load_lora, save_lora, apply_lora
|
||||
from model.model_lora import save_lora, apply_lora
|
||||
from trainer.trainer_utils import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
# Logger function
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
# 代码和full_sft「几乎」一致
|
||||
def train_epoch(epoch, wandb):
|
||||
def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
start_time = time.time()
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
with autocast_ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
).view(Y.size())
|
||||
|
||||
loss = (loss * loss_mask).sum() / loss_mask.sum()
|
||||
loss += res.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
@ -64,146 +54,122 @@ def train_epoch(epoch, wandb):
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item() * args.accumulation_steps,
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
|
||||
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
|
||||
|
||||
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": loss.item() * args.accumulation_steps,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
lora_save_path = f'{args.save_dir}/lora/{args.lora_name}_{lm_config.hidden_size}.pth'
|
||||
os.makedirs(os.path.dirname(lora_save_path), exist_ok=True)
|
||||
# 【区别1】只保存lora权重即可
|
||||
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
|
||||
# LoRA只保存LoRA权重
|
||||
save_lora(model, lora_save_path)
|
||||
lm_checkpoint(lm_config, weight=args.lora_name, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model.to(args.device), tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
parser.add_argument("--epochs", type=int, default=50)
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA-SFT")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=10)
|
||||
parser.add_argument("--save_interval", type=int, default=1)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl")
|
||||
parser.add_argument("--lora_name", type=str, default="lora_identity", help="根据任务保存成lora_(英文/医学/心理...)")
|
||||
parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
|
||||
parser.add_argument("--save_dir", type=str, default="../out/lora", help="模型保存目录")
|
||||
parser.add_argument("--lora_name", type=str, default="lora_identity", help="LoRA权重名称(如lora_identity/lora_medical等)")
|
||||
parser.add_argument("--epochs", type=int, default=50, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=10, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=1, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
|
||||
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/lora_identity.jsonl", help="LoRA训练数据路径")
|
||||
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练,默认full_sft")
|
||||
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb项目名")
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * args.max_seq_len
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.lora_name, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# 同时设置 CUDA 的随机种子
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Lora-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, tokenizer = init_model(lm_config)
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、应用LoRA、冻结非LoRA参数 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
apply_lora(model)
|
||||
|
||||
total_params = sum(p.numel() for p in model.parameters()) # 总参数数量
|
||||
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name) # LoRA 参数数量
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(f"LLM 总参数量: {total_params}")
|
||||
print(f"LoRA 参数量: {lora_params_count}")
|
||||
print(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if 'lora' not in name:
|
||||
param.requires_grad = False
|
||||
|
||||
# 统计参数
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
|
||||
Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
|
||||
Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
|
||||
Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
|
||||
|
||||
# 冻结非LoRA参数,收集LoRA参数
|
||||
lora_params = []
|
||||
for name, param in model.named_parameters():
|
||||
if 'lora' in name:
|
||||
param.requires_grad = True
|
||||
lora_params.append(param)
|
||||
|
||||
# 只对 LoRA 参数进行优化
|
||||
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
# ========== 6. 定义数据和优化器 ==========
|
||||
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
iter_per_epoch = len(train_loader)
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
|
||||
|
||||
# ========== 7. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'], strict=False)
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 8. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 9. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
train_epoch(epoch, wandb)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, lora_params, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), lora_params, 0, wandb)
|
||||
|
||||
@ -6,27 +6,43 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoTokenizer
|
||||
from contextlib import nullcontext
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from transformers import AutoModel
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from trainer.trainer_utils import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
# 自定义的Critic模型,继承自MiniMindLM
|
||||
class CriticModel(MiniMindForCausalLM):
|
||||
def __init__(self, params):
|
||||
super().__init__(params)
|
||||
# 替换lm_head为输出单一价值的线性层
|
||||
self.value_head = nn.Linear(params.hidden_size, 1)
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
||||
# 使用基础模型获取隐藏状态
|
||||
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||
hidden_states = self.model.norm(outputs[0])
|
||||
# 使用value_head获取价值估计
|
||||
values = self.value_head(hidden_states).squeeze(-1)
|
||||
return values
|
||||
|
||||
|
||||
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
"""整合所有奖励函数计算总奖励"""
|
||||
|
||||
def reasoning_model_reward(rewards):
|
||||
# 1. 格式奖励(仅针对训练推理模型时使用)
|
||||
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
|
||||
@ -66,7 +82,7 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
|
||||
# 格式奖励
|
||||
if args.reasoning == 1:
|
||||
rewards = reasoning_model_reward(rewards) # 训练推理模型时使用
|
||||
rewards = reasoning_model_reward(rewards)
|
||||
|
||||
# 使用reward model计算整个response的奖励
|
||||
with torch.no_grad():
|
||||
@ -91,7 +107,6 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
|
||||
answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
|
||||
answer_score = max(min(answer_score, scale), -scale)
|
||||
|
||||
score = score * 0.4 + answer_score * 0.6
|
||||
reward_model_scores.append(score)
|
||||
|
||||
@ -101,19 +116,20 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
return rewards
|
||||
|
||||
|
||||
def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_scheduler, critic_scheduler):
|
||||
def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step=0, wandb=None):
|
||||
actor_model.train()
|
||||
critic_model.train()
|
||||
is_master = (not ddp) or dist.get_rank() == 0
|
||||
|
||||
for step, batch in enumerate(train_loader):
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
prompts = batch["prompt"] # list[str], length B
|
||||
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
|
||||
max_length=args.max_seq_len).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||
prompt_lengths = enc.attention_mask.sum(dim=1) # [B]
|
||||
|
||||
with torch.no_grad():
|
||||
gen_out = actor_model.generate(
|
||||
# DDP 模型需要使用 .module 访问 generate 方法
|
||||
model_for_gen = actor_model.module if isinstance(actor_model, DistributedDataParallel) else actor_model
|
||||
gen_out = model_for_gen.generate(
|
||||
input_ids=enc.input_ids, attention_mask=enc.attention_mask,
|
||||
max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
|
||||
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id) # [B, P+R]
|
||||
@ -164,7 +180,7 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
|
||||
actor_optimizer.zero_grad()
|
||||
critic_optimizer.zero_grad()
|
||||
|
||||
if is_master:
|
||||
if is_main_process():
|
||||
response_ids = gen_out[:, enc.input_ids.shape[1]:]
|
||||
is_eos = (response_ids == tokenizer.eos_token_id)
|
||||
eos_indices = torch.argmax(is_eos.int(), dim=1)
|
||||
@ -181,8 +197,8 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
|
||||
actor_lr = actor_optimizer.param_groups[0]['lr']
|
||||
critic_lr = critic_optimizer.param_groups[0]['lr']
|
||||
|
||||
if wandb_run is not None:
|
||||
wandb_run.log({
|
||||
if wandb is not None:
|
||||
wandb.log({
|
||||
"actor_loss": actor_loss_val,
|
||||
"critic_loss": critic_loss_val,
|
||||
"reward": reward_val,
|
||||
@ -192,183 +208,158 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
|
||||
"actor_lr": actor_lr,
|
||||
})
|
||||
|
||||
Logger(f"Epoch: {epoch}, Step: {step + 1}/{len(train_loader)}, "
|
||||
Logger(f"Epoch: {epoch+1}, Step: {step}/{iters}, "
|
||||
f"Actor Loss: {actor_loss_val:.6f}, Critic Loss: {critic_loss_val:.6f}, "
|
||||
f"Reward: {reward_val:.6f}, KL: {kl_val:.6f}, KL_ref: {kl_ref_val:.6f}, "
|
||||
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.2e}, Critic LR: {critic_lr:.2e}")
|
||||
|
||||
if (step + 1) % args.update_old_actor_freq == 0:
|
||||
state_dict = actor_model.module.state_dict() if isinstance(actor_model, torch.nn.parallel.DistributedDataParallel) else actor_model.state_dict()
|
||||
state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
|
||||
old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()})
|
||||
old_actor_model.to(args.device)
|
||||
|
||||
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
actor_model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/ppo_actor_{lm_config.hidden_size}{moe_path}.pth'
|
||||
|
||||
if isinstance(actor_model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = actor_model.module.state_dict()
|
||||
else:
|
||||
state_dict = actor_model.state_dict()
|
||||
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
torch.save(state_dict, ckp)
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
actor_state = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()
|
||||
torch.save({k: v.half() for k, v in actor_state.items()}, ckp)
|
||||
|
||||
# 使用 lm_checkpoint 保存完整状态(包括 critic)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=actor_model, optimizer=actor_optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints',
|
||||
scheduler=actor_scheduler, critic_model=critic_model,
|
||||
critic_optimizer=critic_optimizer, critic_scheduler=critic_scheduler)
|
||||
actor_model.train()
|
||||
|
||||
|
||||
# 自定义的Critic模型,继承自MiniMindLM
|
||||
class CriticModel(MiniMindForCausalLM):
|
||||
def __init__(self, params):
|
||||
super().__init__(params)
|
||||
# 替换lm_head为输出单一价值的线性层
|
||||
self.value_head = nn.Linear(params.hidden_size, 1)
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind PPO (Proximal Policy Optimization)")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='ppo_actor', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=8e-8, help="Actor学习率")
|
||||
parser.add_argument("--critic_learning_rate", type=float, default=8e-8, help="Critic学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
|
||||
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
|
||||
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
|
||||
parser.add_argument("--clip_epsilon", type=float, default=0.1, help="PPO裁剪参数")
|
||||
parser.add_argument("--vf_coef", type=float, default=0.5, help="Value function系数")
|
||||
parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
|
||||
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
|
||||
parser.add_argument("--update_old_actor_freq", type=int, default=4, help="更新old_actor_model的频率")
|
||||
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||||
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO", help="wandb项目名")
|
||||
args = parser.parse_args()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
||||
# 使用基础模型获取隐藏状态
|
||||
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||
# self.model 返回的是一个元组,第一个元素是 last_hidden_state
|
||||
hidden_states = self.model.norm(outputs[0])
|
||||
# 使用value_head获取价值估计
|
||||
values = self.value_head(hidden_states).squeeze(-1)
|
||||
return values
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-PPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 初始化模型和数据 ==========
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/', padding_side='left')
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{"reason" if args.reasoning == 1 else "full_sft"}_{lm_config.hidden_size}{moe_path}.pth'
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
base_weight = "reason" if args.reasoning == 1 else "full_sft"
|
||||
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
|
||||
# Actor模型
|
||||
actor_model = MiniMindForCausalLM(lm_config)
|
||||
actor_model.load_state_dict(state_dict, strict=False)
|
||||
actor_model = actor_model.to(args.device)
|
||||
|
||||
Logger(f'Actor模型总参数量:{sum(p.numel() for p in actor_model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
# Old Actor模型
|
||||
old_actor_model = MiniMindForCausalLM(lm_config)
|
||||
old_actor_model.load_state_dict(state_dict, strict=False)
|
||||
old_actor_model = old_actor_model.eval().requires_grad_(False).to(args.device)
|
||||
|
||||
# Reference模型
|
||||
ref_model = MiniMindForCausalLM(lm_config)
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
ref_model = ref_model.eval().requires_grad_(False).to(args.device)
|
||||
|
||||
# Critic模型
|
||||
critic_model = CriticModel(lm_config)
|
||||
critic_model.load_state_dict(state_dict, strict=False)
|
||||
critic_model = critic_model.to(args.device)
|
||||
|
||||
reward_name = "../../internlm2-1_8b-reward"
|
||||
Logger(f'Critic模型总参数量:{sum(p.numel() for p in critic_model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
# Reward模型
|
||||
reward_model = AutoModel.from_pretrained(
|
||||
reward_name, device_map="cuda", torch_dtype=torch.float32, trust_remote_code=True
|
||||
args.reward_model_path, device_map="cuda", torch_dtype=torch.float32, trust_remote_code=True
|
||||
).to(args.device).eval().requires_grad_(False)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True)
|
||||
|
||||
Logger(f'Actor模型总参数量:{sum(p.numel() for p in actor_model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
Logger(f'Critic模型总参数量:{sum(p.numel() for p in critic_model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
|
||||
return actor_model, old_actor_model, ref_model, critic_model, reward_model, tokenizer, reward_tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
parser.add_argument("--epochs", type=int, default=1)
|
||||
parser.add_argument("--batch_size", type=int, default=2)
|
||||
parser.add_argument("--learning_rate", type=float, default=8e-8)
|
||||
parser.add_argument("--critic_learning_rate", type=float, default=8e-8)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-PPO")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--log_interval", type=int, default=1)
|
||||
parser.add_argument("--save_interval", type=int, default=10)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--max_seq_len', default=66, type=int)
|
||||
parser.add_argument("--max_gen_len", type=int, default=1536)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
|
||||
parser.add_argument("--clip_epsilon", type=float, default=0.1)
|
||||
parser.add_argument("--vf_coef", type=float, default=0.5)
|
||||
parser.add_argument("--kl_coef", type=float, default=0.02, help="KL散度惩罚系数")
|
||||
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
|
||||
parser.add_argument("--update_old_actor_freq", type=int, default=4, help="频率:每处理n个batch后更新old_actor_model")
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# 同时设置 CUDA 的随机种子
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import swanlab as wandb
|
||||
|
||||
wandb.init(project=args.wandb_project)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
# 初始化所有模型
|
||||
actor_model, old_actor_model, ref_model, critic_model, reward_model, tokenizer, reward_tokenizer = init_model(lm_config=lm_config)
|
||||
|
||||
# 准备数据集和数据加载器
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
|
||||
# 数据和优化器
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len))
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
|
||||
# 初始化优化器
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
actor_optimizer = optim.AdamW(actor_model.parameters(), lr=args.learning_rate)
|
||||
critic_optimizer = optim.AdamW(critic_model.parameters(), lr=args.critic_learning_rate)
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs
|
||||
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
|
||||
iters = len(loader_for_count)
|
||||
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
|
||||
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps,
|
||||
eta_min=args.critic_learning_rate / 10)
|
||||
|
||||
# 如果使用分布式训练,包装模型
|
||||
if ddp:
|
||||
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=total_optimizer_steps, eta_min=args.critic_learning_rate / 10)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
actor_model.load_state_dict(ckp_data['model'])
|
||||
critic_model.load_state_dict(ckp_data['critic_model'])
|
||||
actor_optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
critic_optimizer.load_state_dict(ckp_data['critic_optimizer'])
|
||||
actor_scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
critic_scheduler.load_state_dict(ckp_data['critic_scheduler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
actor_model = DistributedDataParallel(actor_model, device_ids=[ddp_local_rank])
|
||||
critic_model = DistributedDataParallel(critic_model, device_ids=[ddp_local_rank])
|
||||
# old_actor_model 不需要DDP包装,因为它只在主进程上用于计算,并且不进行梯度更新
|
||||
actor_model = DistributedDataParallel(actor_model, device_ids=[local_rank])
|
||||
critic_model = DistributedDataParallel(critic_model, device_ids=[local_rank])
|
||||
old_actor_model.to(args.device)
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
ppo_train_epoch(epoch, wandb, old_actor_model, ref_model, actor_scheduler, critic_scheduler)
|
||||
|
||||
if ddp:
|
||||
dist.destroy_process_group()
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
ppo_train_epoch(epoch, loader, len(loader) + start_step + 1, old_actor_model, ref_model,
|
||||
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None),
|
||||
sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
ppo_train_epoch(epoch, loader, len(loader), old_actor_model, ref_model,
|
||||
actor_scheduler, critic_scheduler, reward_model, reward_tokenizer, 0, wandb)
|
||||
|
||||
@ -6,48 +6,38 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
from torch import optim, nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from contextlib import nullcontext
|
||||
from transformers import AutoTokenizer
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from model.model_minimind import MiniMindConfig
|
||||
from dataset.lm_dataset import PretrainDataset
|
||||
from trainer.trainer_utils import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def train_epoch(epoch, wandb):
|
||||
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
start_time = time.time()
|
||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||
for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
|
||||
X = X.to(args.device)
|
||||
Y = Y.to(args.device)
|
||||
loss_mask = loss_mask.to(args.device)
|
||||
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
with ctx:
|
||||
with autocast_ctx:
|
||||
res = model(X)
|
||||
loss = loss_fct(
|
||||
res.logits.view(-1, res.logits.size(-1)),
|
||||
Y.view(-1)
|
||||
).view(Y.size())
|
||||
|
||||
loss = (loss * loss_mask).sum() / loss_mask.sum()
|
||||
loss += res.aux_loss
|
||||
loss = loss / args.accumulation_steps
|
||||
@ -63,139 +53,108 @@ def train_epoch(epoch, wandb):
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
if step % args.log_interval == 0 or step == iters - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
iter_per_epoch,
|
||||
loss.item() * args.accumulation_steps,
|
||||
optimizer.param_groups[-1]['lr'],
|
||||
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
|
||||
current_loss = loss.item() * args.accumulation_steps
|
||||
current_lr = optimizer.param_groups[-1]['lr']
|
||||
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
|
||||
|
||||
Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
|
||||
|
||||
if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
|
||||
|
||||
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
|
||||
wandb.log({"loss": loss.item() * args.accumulation_steps,
|
||||
"lr": optimizer.param_groups[-1]['lr'],
|
||||
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
|
||||
|
||||
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/pretrain_{lm_config.hidden_size}{moe_path}.pth'
|
||||
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()} # 半精度保存
|
||||
torch.save(state_dict, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
|
||||
model.train()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
model = MiniMindForCausalLM(lm_config).to(args.device)
|
||||
Logger(f'LLM可训练总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_rank = int(os.environ["RANK"])
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
# torchrun --nproc_per_node 2 1-pretrain.py
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
# 若要以最快速度实现zero则epochs设置为1轮;否则应当利用有限的数据训练2~6个epochs。
|
||||
parser.add_argument("--epochs", type=int, default=1)
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-4)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--warmup_iters", type=int, default=0)
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=100)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='pretrain', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数(建议1轮zero或2-6轮充分训练)")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-4, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--max_seq_len', default=512, type=int, help="训练的最大截断长度")
|
||||
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/pretrain_hq.jsonl", help="预训练数据路径")
|
||||
parser.add_argument('--from_weight', default='none', type=str, help="基于哪个权重训练,为none则从头开始")
|
||||
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="wandb项目名")
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
tokens_per_iter = args.batch_size * args.max_seq_len
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=args.use_moe)
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
|
||||
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
|
||||
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# 同时设置 CUDA 的随机种子
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, tokenizer = init_model(lm_config)
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 定义模型、数据、优化器 ==========
|
||||
model, tokenizer = init_model(lm_config, args.from_weight)
|
||||
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if ddp:
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scaler.load_state_dict(ckp_data['scaler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
for epoch in range(args.epochs):
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
train_epoch(epoch, wandb)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
train_epoch(epoch, loader, len(loader), 0, wandb)
|
||||
|
||||
@ -3,31 +3,35 @@ import sys
|
||||
|
||||
__package__ = "trainer"
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import re
|
||||
import gc
|
||||
import warnings
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
import torch.distributed as dist
|
||||
from transformers import AutoTokenizer
|
||||
from contextlib import nullcontext
|
||||
from torch import optim
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from transformers import AutoModel
|
||||
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
|
||||
from dataset.lm_dataset import RLAIFDataset
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from collections import defaultdict
|
||||
from trainer.trainer_utils import *
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
class AutoAdaptiveValueTracker:
|
||||
"""SPO自适应价值追踪器"""
|
||||
def __init__(self, rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96):
|
||||
self.rho_mode = rho_mode
|
||||
self.rho_const = rho_const
|
||||
self.D_half = D_half
|
||||
self.clip_lower = clip_lower
|
||||
self.clip_upper = clip_upper
|
||||
# Stable initialization following N_init = 1/(1-clip_lower)
|
||||
N_init = 1.0 / (1.0 - self.clip_lower)
|
||||
self.alpha = 0.5 * N_init
|
||||
self.beta = 0.5 * N_init
|
||||
@ -62,43 +66,28 @@ class AutoAdaptiveValueTracker:
|
||||
return rho
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if not ddp or dist.get_rank() == 0:
|
||||
print(content)
|
||||
|
||||
|
||||
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
"""整合所有奖励函数计算总奖励"""
|
||||
|
||||
def reasoning_model_reward(rewards):
|
||||
# 1. 格式奖励(仅针对训练推理模型时使用)
|
||||
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
|
||||
pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
|
||||
|
||||
matches_pattern = [re.match(pattern, response, re.S) for response in responses]
|
||||
matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
|
||||
|
||||
format_rewards = []
|
||||
for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
|
||||
if match_pattern:
|
||||
format_rewards.append(0.5)
|
||||
elif match_pattern2:
|
||||
if match_pattern or match_pattern2:
|
||||
format_rewards.append(0.5)
|
||||
else:
|
||||
format_rewards.append(0.0)
|
||||
rewards += torch.tensor(format_rewards, device=args.device)
|
||||
|
||||
# 2. 标记奖励(防止严格奖励稀疏,仅针对训练推理模型时使用)
|
||||
def mark_num(text):
|
||||
reward = 0
|
||||
if text.count("<think>") == 1:
|
||||
reward += 0.25
|
||||
if text.count("</think>") == 1:
|
||||
reward += 0.25
|
||||
if text.count("<answer>") == 1:
|
||||
reward += 0.25
|
||||
if text.count("</answer>") == 1:
|
||||
reward += 0.25
|
||||
if text.count("<think>") == 1: reward += 0.25
|
||||
if text.count("</think>") == 1: reward += 0.25
|
||||
if text.count("<answer>") == 1: reward += 0.25
|
||||
if text.count("</answer>") == 1: reward += 0.25
|
||||
return reward
|
||||
|
||||
mark_rewards = [mark_num(response) for response in responses]
|
||||
@ -106,12 +95,9 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
return rewards
|
||||
|
||||
rewards = torch.zeros(len(responses), device=args.device)
|
||||
|
||||
# 3. 格式奖励
|
||||
if args.reasoning == 1:
|
||||
rewards = reasoning_model_reward(rewards) # 训练推理模型时使用
|
||||
rewards = reasoning_model_reward(rewards)
|
||||
|
||||
# 4. 使用reward model计算奖励
|
||||
with torch.no_grad():
|
||||
reward_model_scores = []
|
||||
scale = 3.0
|
||||
@ -142,8 +128,8 @@ def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
|
||||
return rewards
|
||||
|
||||
|
||||
def spo_train_epoch(epoch, wandb, value_tracker):
|
||||
for step, batch in enumerate(train_loader):
|
||||
def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, value_tracker, start_step=0, wandb=None):
|
||||
for step, batch in enumerate(loader, start=start_step + 1):
|
||||
prompts = batch['prompt'] # list[str], length B
|
||||
prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
|
||||
padding_side="left", add_special_tokens=False).to(args.device) # input_ids: [B, P], attention_mask: [B, P]
|
||||
@ -152,7 +138,9 @@ def spo_train_epoch(epoch, wandb, value_tracker):
|
||||
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = (model.module if ddp else model).generate(
|
||||
# DDP 模型需要使用 .module 访问 generate 方法
|
||||
model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
|
||||
outputs = model_for_gen.generate(
|
||||
**prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
|
||||
num_return_sequences=1, pad_token_id=tokenizer.pad_token_id) # [B, P+R]
|
||||
|
||||
@ -205,42 +193,38 @@ def spo_train_epoch(epoch, wandb, value_tracker):
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
if step % args.log_interval == 0 or step == iters:
|
||||
policy_loss_val = loss.item()
|
||||
avg_reward_val = rewards.mean().item()
|
||||
avg_len_val = completion_mask.sum(dim=1).float().mean().item()
|
||||
# average kl over valid tokens for logging
|
||||
kl_val = ((per_token_kl * completion_mask).sum() / (completion_mask.sum() + 1e-8)).item()
|
||||
avg_baseline_val = baselines.mean().item()
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
|
||||
Logger(
|
||||
f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
|
||||
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
|
||||
f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
|
||||
Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
|
||||
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
|
||||
f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, '
|
||||
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
|
||||
|
||||
if wandb and (not ddp or dist.get_rank() == 0):
|
||||
log_dict = {
|
||||
if wandb and is_main_process():
|
||||
wandb.log({
|
||||
"policy_loss": policy_loss_val,
|
||||
"reward": avg_reward_val,
|
||||
"kl": kl_val,
|
||||
"rho": float(rho),
|
||||
"baseline": avg_baseline_val,
|
||||
# "avg_response_len": avg_len_val,
|
||||
"advantages_mean": advantages.mean().item(),
|
||||
"learning_rate": current_lr
|
||||
}
|
||||
wandb.log(log_dict)
|
||||
})
|
||||
|
||||
if ((step + 1) % args.save_interval == 0 or step == iter_per_epoch - 1) and (not ddp or dist.get_rank() == 0):
|
||||
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
|
||||
model.eval()
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
suffix = 'spo'
|
||||
ckp = f'{args.save_dir}/{suffix}_{lm_config.hidden_size}{moe_path}.pth'
|
||||
|
||||
state_dict = model.module.state_dict() if isinstance(model,
|
||||
torch.nn.parallel.DistributedDataParallel) else model.state_dict()
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
||||
torch.save({k: v.half() for k, v in state_dict.items()}, ckp)
|
||||
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
|
||||
epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
|
||||
model.train()
|
||||
|
||||
del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
|
||||
@ -249,120 +233,116 @@ def spo_train_epoch(epoch, wandb, value_tracker):
|
||||
gc.collect()
|
||||
|
||||
|
||||
def init_model(lm_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{moe_path}.pth'
|
||||
if args.reasoning == 1:
|
||||
ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind SPO (Self-Play Optimization)")
|
||||
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
|
||||
parser.add_argument('--save_weight', default='spo', type=str, help="保存权重的前缀名")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-7, help="初始学习率")
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
|
||||
parser.add_argument("--num_workers", type=int, default=1, help="数据加载线程数")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=4, help="梯度累积步数")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
|
||||
parser.add_argument("--log_interval", type=int, default=1, help="日志打印间隔")
|
||||
parser.add_argument("--save_interval", type=int, default=10, help="模型保存间隔")
|
||||
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
|
||||
parser.add_argument('--use_moe', default=False, type=bool, help="是否使用MoE")
|
||||
parser.add_argument('--max_seq_len', default=66, type=int, help="Prompt最大长度")
|
||||
parser.add_argument("--max_gen_len", type=int, default=1536, help="生成的最大长度")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="RLAIF数据路径")
|
||||
parser.add_argument("--beta", type=float, default=0.02, help="KL惩罚系数")
|
||||
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
|
||||
parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Reward模型路径")
|
||||
parser.add_argument('--from_resume', default=0, type=int, help="是否自动检测&续训,0否1是")
|
||||
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO", help="wandb项目名")
|
||||
args = parser.parse_args()
|
||||
|
||||
# ========== 1. 初始化环境和随机种子 ==========
|
||||
local_rank = init_distributed_mode()
|
||||
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
|
||||
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
|
||||
|
||||
# ========== 2. 配置目录、模型参数、检查ckp ==========
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=args.use_moe)
|
||||
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
|
||||
|
||||
# ========== 3. 设置混合精度 ==========
|
||||
device_type = "cuda" if "cuda" in args.device else "cpu"
|
||||
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
|
||||
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
|
||||
|
||||
# ========== 4. 配wandb ==========
|
||||
wandb = None
|
||||
if args.use_wandb and is_main_process():
|
||||
import swanlab as wandb
|
||||
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
|
||||
resume = 'must' if wandb_id else None
|
||||
wandb_run_name = f"MiniMind-SPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
|
||||
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
|
||||
|
||||
# ========== 5. 初始化模型(Policy, Ref, Reward)和Value Tracker、数据 ==========
|
||||
tokenizer = AutoTokenizer.from_pretrained('../model/')
|
||||
moe_suffix = '_moe' if lm_config.use_moe else ''
|
||||
base_weight = "reason" if args.reasoning == 1 else "full_sft"
|
||||
ckp = f'{args.save_dir}/{base_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
|
||||
state_dict = torch.load(ckp, map_location=args.device)
|
||||
# Policy模型
|
||||
model = MiniMindForCausalLM(lm_config)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = model.to(args.device)
|
||||
Logger(f'Policy模型总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M')
|
||||
# Reference模型
|
||||
ref_model = MiniMindForCausalLM(lm_config)
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
ref_model.eval().requires_grad_(False)
|
||||
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
model = model.to(args.device)
|
||||
ref_model = ref_model.to(args.device)
|
||||
|
||||
reward_name = "../../internlm2-1_8b-reward"
|
||||
# Reward模型
|
||||
reward_model = AutoModel.from_pretrained(
|
||||
reward_name,
|
||||
device_map="cuda",
|
||||
torch_dtype=torch.float16,
|
||||
trust_remote_code=True,
|
||||
args.reward_model_path, device_map="cuda", torch_dtype=torch.float16, trust_remote_code=True
|
||||
).to(args.device).eval().requires_grad_(False)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(reward_name, trust_remote_code=True)
|
||||
|
||||
return model, ref_model, tokenizer, reward_model, reward_tokenizer
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if not ddp: return
|
||||
global ddp_local_rank, DEVICE
|
||||
dist.init_process_group(backend="nccl")
|
||||
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||
DEVICE = f"cuda:{ddp_local_rank}"
|
||||
torch.cuda.set_device(DEVICE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
parser.add_argument("--epochs", type=int, default=1)
|
||||
parser.add_argument("--batch_size", type=int, default=2)
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-7)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-SPO")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=4)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
parser.add_argument("--log_interval", type=int, default=1)
|
||||
parser.add_argument("--save_interval", type=int, default=10)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--max_seq_len', default=66, type=int)
|
||||
parser.add_argument("--max_gen_len", type=int, default=1536)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl")
|
||||
parser.add_argument("--beta", type=float, default=0.02)
|
||||
parser.add_argument("--reasoning", type=int, default=1, help='0:普通模型,1:推理模型')
|
||||
args = parser.parse_args()
|
||||
|
||||
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
|
||||
max_seq_len=args.max_seq_len + args.max_gen_len,
|
||||
use_moe=args.use_moe)
|
||||
args.save_dir = os.path.join(args.out_dir)
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
ctx = nullcontext() if "cuda" not in args.device else torch.amp.autocast('cuda')
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1
|
||||
ddp_local_rank, DEVICE = 0, "cuda:0"
|
||||
|
||||
base_seed = 1337
|
||||
torch.manual_seed(base_seed)
|
||||
torch.cuda.manual_seed(base_seed)
|
||||
|
||||
if ddp:
|
||||
init_distributed_mode()
|
||||
args.device = torch.device(DEVICE)
|
||||
rank = dist.get_rank()
|
||||
torch.manual_seed(base_seed + rank)
|
||||
# 同时设置 CUDA 的随机种子
|
||||
torch.cuda.manual_seed(base_seed + rank)
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import swanlab as wandb
|
||||
|
||||
wandb.init(project=args.wandb_project)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
model, ref_model, tokenizer, reward_model, reward_tokenizer = init_model(lm_config=lm_config)
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
|
||||
# Value Tracker
|
||||
value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
|
||||
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
|
||||
iters = len(loader_for_count)
|
||||
total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
|
||||
# ========== 6. 从ckp恢复状态 ==========
|
||||
start_epoch, start_step = 0, 0
|
||||
if ckp_data:
|
||||
model.load_state_dict(ckp_data['model'])
|
||||
optimizer.load_state_dict(ckp_data['optimizer'])
|
||||
scheduler.load_state_dict(ckp_data['scheduler'])
|
||||
start_epoch = ckp_data['epoch']
|
||||
start_step = ckp_data.get('step', 0)
|
||||
|
||||
# ========== 7. DDP包模型 ==========
|
||||
if dist.is_initialized():
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank])
|
||||
|
||||
# ========== 8. 开始训练 ==========
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
|
||||
batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
|
||||
spo_train_epoch(epoch, loader, len(loader) + start_step + 1, ref_model, reward_model, reward_tokenizer, value_tracker, start_step, wandb)
|
||||
else: # 默认从头开始
|
||||
loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
iter_per_epoch = len(train_loader)
|
||||
total_optimizer_steps = (iter_per_epoch // args.accumulation_steps) * args.epochs
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
|
||||
|
||||
if ddp:
|
||||
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
|
||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||
|
||||
value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
spo_train_epoch(epoch, wandb, value_tracker)
|
||||
spo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, value_tracker, 0, wandb)
|
||||
|
||||
139
trainer/trainer_utils.py
Normal file
139
trainer/trainer_utils.py
Normal file
@ -0,0 +1,139 @@
|
||||
"""
|
||||
训练工具函数集合
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return not dist.is_initialized() or dist.get_rank() == 0
|
||||
|
||||
|
||||
def Logger(content):
|
||||
if is_main_process():
|
||||
print(content)
|
||||
|
||||
|
||||
def get_lr(current_step, total_steps, lr):
|
||||
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
|
||||
|
||||
|
||||
def init_distributed_mode():
|
||||
if int(os.environ.get("RANK", -1)) == -1:
|
||||
return 0 # 非DDP模式
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local_rank)
|
||||
return local_rank
|
||||
|
||||
|
||||
def setup_seed(seed: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def lm_checkpoint(lm_config, weight='full_sft', model=None, optimizer=None, epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
    """Save or load a training checkpoint, depending on whether *model* is given.

    Save mode (model is not None) writes two files under *save_dir*:
      - '<weight>_<hidden>[_moe].pth'        : fp16 model weights only
      - '<weight>_<hidden>[_moe]_resume.pth' : full resume state (model,
        optimizer, epoch/step counters, world size, wandb run id, plus any
        extra state passed via **kwargs)

    Load mode (model is None) returns the resume dict if the resume file
    exists, else None. If the process count changed since saving, 'step'
    is rescaled accordingly.

    Args:
        lm_config: model config; only .use_moe and .hidden_size are read here.
        weight: filename prefix identifying the training stage.
        model: model to save (plain or DDP-wrapped); None selects load mode.
        optimizer: optimizer whose state goes into the resume file
            (required in save mode — .state_dict() is called unconditionally).
        epoch: epoch counter stored for resumption.
        step: step counter stored for resumption.
        wandb: wandb handle used to recover the run id for log resumption.
        save_dir: directory holding both checkpoint files.
        **kwargs: extra objects; anything exposing .state_dict() is
            serialized via it, other non-None values are stored as-is.
    """
    os.makedirs(save_dir, exist_ok=True)
    moe_path = '_moe' if lm_config.use_moe else ''
    # weights-only file (for inference/eval) and full resume file share a prefix
    ckp_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}.pth'
    resume_path = f'{save_dir}/{weight}_{lm_config.hidden_size}{moe_path}_resume.pth'

    if model is not None:
        from torch.nn.parallel import DistributedDataParallel
        # unwrap DDP so saved keys carry no 'module.' prefix
        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
        # write to a temp file, then atomically replace, so a crash mid-save
        # never leaves a truncated checkpoint behind
        ckp_tmp = ckp_path + '.tmp'
        # cast to fp16 to halve the weights-only file size
        torch.save({k: v.half() for k, v in state_dict.items()}, ckp_tmp)
        os.replace(ckp_tmp, ckp_path)
        wandb_id = None
        if wandb:
            # support both a wrapper exposing get_run() and a raw run object with .id
            if hasattr(wandb, 'get_run'):
                run = wandb.get_run()
                wandb_id = getattr(run, 'id', None) if run else None
            else:
                wandb_id = getattr(wandb, 'id', None)

        resume_data = {
            'model': state_dict,
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'step': step,
            # recorded so 'step' can be rescaled if the GPU count changes
            'world_size': dist.get_world_size() if dist.is_initialized() else 1,
            'wandb_id': wandb_id
        }
        for key, value in kwargs.items():
            if value is not None:
                if hasattr(value, 'state_dict'):
                    # unwrap DDP-wrapped extras the same way as the main model
                    if isinstance(value, DistributedDataParallel):
                        resume_data[key] = value.module.state_dict()
                    else:
                        resume_data[key] = value.state_dict()
                else:
                    resume_data[key] = value

        # same atomic tmp-then-replace pattern for the resume file
        resume_tmp = resume_path + '.tmp'
        torch.save(resume_data, resume_tmp)
        os.replace(resume_tmp, resume_path)
    else:  # load mode
        if os.path.exists(resume_path):
            ckp_data = torch.load(resume_path, map_location='cpu')
            saved_ws = ckp_data.get('world_size', 1)
            current_ws = dist.get_world_size() if dist.is_initialized() else 1
            if saved_ws != current_ws:
                # rescale per-rank step so total consumed data stays constant
                ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
                Logger(f'GPU数量变化({saved_ws}→{current_ws}),step已自动转换为{ckp_data["step"]}')
            return ckp_data
        return None
|
||||
|
||||
|
||||
def init_model(lm_config, from_weight='pretrain', tokenizer_path='../model', save_dir='../out', device='cuda'):
    """Build a MiniMind model and its tokenizer, optionally loading weights.

    Args:
        lm_config: config describing the model architecture.
        from_weight: weight-file prefix to load; 'none' skips loading.
        tokenizer_path: directory holding the tokenizer files.
        save_dir: directory containing the .pth weight files.
        device: device the model is moved to and weights are mapped onto.

    Returns:
        (model, tokenizer) with the model already on *device*.
    """
    from transformers import AutoTokenizer
    from model.model_minimind import MiniMindForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    model = MiniMindForCausalLM(lm_config)

    if from_weight != 'none':
        suffix = '_moe' if lm_config.use_moe else ''
        weight_file = f'{save_dir}/{from_weight}_{lm_config.hidden_size}{suffix}.pth'
        state = torch.load(weight_file, map_location=device)
        # strict=False tolerates missing/unexpected keys in the stored weights
        model.load_state_dict(state, strict=False)

    trainable_millions = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    Logger(f'所加载Model可训练参数:{trainable_millions:.3f} 百万')
    return model.to(device), tokenizer
|
||||
|
||||
|
||||
class SkipBatchSampler(Sampler):
    """Batch sampler that silently drops the first *skip_batches* batches.

    Lets a resumed run fast-forward its DataLoader to where the previous
    run stopped, without disturbing the underlying sampler's order.
    """

    def __init__(self, sampler, batch_size, skip_batches=0):
        self.sampler = sampler
        self.batch_size = batch_size
        self.skip_batches = skip_batches

    def __iter__(self):
        pending = []
        remaining_skips = self.skip_batches
        for index in self.sampler:
            pending.append(index)
            if len(pending) < self.batch_size:
                continue
            # a full batch is ready: either consume a skip or yield it
            if remaining_skips > 0:
                remaining_skips -= 1
            else:
                yield pending
            pending = []
        # trailing partial batch is emitted only once the skip quota is spent
        if pending and remaining_skips <= 0:
            yield pending

    def __len__(self):
        full_count = (len(self.sampler) + self.batch_size - 1) // self.batch_size
        return max(0, full_count - self.skip_batches)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user